1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-17 01:43:02 +02:00

V5.0.0-alpha

This commit is contained in:
toimtoimtoim
2021-07-15 23:34:01 +03:00
parent d5f883707b
commit 829ddef710
93 changed files with 9738 additions and 7363 deletions

View File

@ -19,7 +19,7 @@ on:
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
workflow_dispatch:
workflow_dispatch: # to be able to run workflow manually
jobs:
test:
@ -28,7 +28,8 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases
go: [1.16, 1.17, 1.18]
# except v5 starts from 1.17 until there is last four major releases after that
go: [1.17, 1.18]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
@ -52,7 +53,7 @@ jobs:
- name: Upload coverage to Codecov
if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v2
with:
token:
fail_ci_if_error: false

View File

@ -1,21 +0,0 @@
arch:
- amd64
- ppc64le
language: go
go:
- 1.14.x
- 1.15.x
- tip
env:
- GO111MODULE=on
install:
- go get -v golang.org/x/lint/golint
script:
- golint -set_exit_status ./...
- go test -race -coverprofile=coverage.txt -covermode=atomic ./...
after_success:
- bash <(curl -s https://codecov.io/bash)
matrix:
allow_failures:
- go: tip

View File

@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2021 LabStack
Copyright (c) 2022 LabStack
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@ -24,11 +24,11 @@ race: ## Run tests with data race detector
@go test -race ${PKG_LIST}
benchmark: ## Run benchmarks
@go test -run="-" -bench=".*" ${PKG_LIST}
@go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.16"
test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.16
goversion ?= "1.17"
test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"

View File

@ -11,12 +11,14 @@
## Supported Go versions
Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required:
- 1.9.7+
- 1.10.3+
- 1.14+
- 1.16+
Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
way of using Echo going forward.
@ -39,24 +41,13 @@ For older versions, please use the latest v3 tag.
- Automatic TLS via Let’s Encrypt
- HTTP/2 support
## Benchmarks
Date: 2020/11/11<br>
Source: https://github.com/vishr/web-framework-benchmark<br>
Lower is better!
<img src="https://i.imgur.com/qwPNQbl.png">
<img src="https://i.imgur.com/s8yKQjx.png">
The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
## [Guide](https://echo.labstack.com/guide)
### Installation
```sh
// go get github.com/labstack/echo/{version}
go get github.com/labstack/echo/v4
go get github.com/labstack/echo/v5
```
### Example
@ -65,8 +56,8 @@ go get github.com/labstack/echo/v4
package main
import (
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"net/http"
)
@ -82,7 +73,9 @@ func main() {
e.GET("/", hello)
// Start server
e.Logger.Fatal(e.Start(":1323"))
if err := e.Start(":1323"); err != http.ErrServerClosed {
log.Fatal(err)
}
}
// Handler
@ -94,14 +87,14 @@ func hello(c echo.Context) error {
# Third-party middlewares
| Repository | Description |
|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo.
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
Please send a PR to add your own library here.

84
bind.go
View File

@ -11,42 +11,38 @@ import (
"strings"
)
type (
// Binder is the interface that wraps the Bind method.
Binder interface {
Bind(i interface{}, c Context) error
type Binder interface {
Bind(c Context, i interface{}) error
}
// DefaultBinder is the default implementation of the Binder interface.
DefaultBinder struct{}
type DefaultBinder struct{}
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead.
BindUnmarshaler interface {
type BindUnmarshaler interface {
// UnmarshalParam decodes and assigns a value from an form or query param.
UnmarshalParam(param string) error
}
)
// BindPathParams binds path params to bindable object
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error {
names := c.ParamNames()
values := c.ParamValues()
func BindPathParams(c Context, i interface{}) error {
params := map[string][]string{}
for i, name := range names {
params[name] = []string{values[i]}
for _, param := range c.PathParams() {
params[param.Name] = []string{param.Value}
}
if err := b.bindData(i, params, "param"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
if err := bindData(i, params, "param"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
// BindQueryParams binds query params to bindable object
func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
if err := b.bindData(i, c.QueryParams(), "query"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
func BindQueryParams(c Context, i interface{}) error {
if err := bindData(i, c.QueryParams(), "query"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
func BindBody(c Context, i interface{}) (err error) {
req := c.Request()
if req.ContentLength == 0 {
return
@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
case *HTTPError:
return err
default:
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
}
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
}
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
params, err := c.FormParams()
values, err := c.FormValues()
if err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
if err = b.bindData(i, params, "form"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
if err = bindData(i, values, "form"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
default:
return ErrUnsupportedMediaType
@ -97,18 +93,18 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
}
// BindHeaders binds HTTP headers to a bindable object
func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error {
if err := b.bindData(i, c.Request().Header, "header"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
func BindHeaders(c Context, i interface{}) error {
if err := bindData(i, c.Request().Header, "header"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
// Bind implements the `Binder#Bind` function.
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
if err := b.BindPathParams(c, i); err != nil {
// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) {
if err := BindPathParams(c, i); err != nil {
return err
}
// Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
// The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670)
method := c.Request().Method
if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
if err = b.BindQueryParams(c, i); err != nil {
if err = BindQueryParams(c, i); err != nil {
return err
}
}
return b.BindBody(c, i)
return BindBody(c, i)
}
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error {
func bindData(destination interface{}, data map[string][]string, tag string) error {
if destination == nil || len(data) == 0 {
return nil
}
@ -167,10 +163,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
}
if inputFieldName == "" {
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
// structs that implement BindUnmarshaler are binded only when they have explicit tag
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags).
// structs that implement BindUnmarshaler are bound only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
if err := bindData(structField.Addr().Interface(), data, tag); err != nil {
return err
}
}
@ -180,10 +176,8 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
inputValue, exists := data[inputFieldName]
if !exists {
// Go json.Unmarshal supports case insensitive binding. However the
// url params are bound case sensitive which is inconsistent. To
// fix this we must check all of the map values in a
// case-insensitive search.
// Go json.Unmarshal supports case-insensitive binding. However, the url params are bound case-sensitive which
// is inconsistent. To fix this we must check all the map values in a case-insensitive search.
for k, v := range data {
if strings.EqualFold(k, inputFieldName) {
inputValue = v
@ -297,7 +291,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0"
return nil
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err == nil {
@ -308,7 +302,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
func setUintField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0"
return nil
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err == nil {
@ -319,7 +313,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
func setBoolField(value string, field reflect.Value) error {
if value == "" {
value = "false"
return nil
}
boolVal, err := strconv.ParseBool(value)
if err == nil {
@ -330,7 +324,7 @@ func setBoolField(value string, field reflect.Value) error {
func setFloatField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0.0"
return nil
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err == nil {

View File

@ -277,7 +277,7 @@ func TestBindHeaderParam(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
u := new(user)
err := (&DefaultBinder{}).BindHeaders(c, u)
err := BindHeaders(c, u)
if assert.NoError(t, err) {
assert.Equal(t, 2, u.ID)
assert.Equal(t, "Jon Doe", u.Name)
@ -291,7 +291,7 @@ func TestBindHeaderParamBadType(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
u := new(user)
err := (&DefaultBinder{}).BindHeaders(c, u)
err := BindHeaders(c, u)
assert.Error(t, err)
httpErr, ok := err.(*HTTPError)
@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) {
}
}
func TestBind_CombineQueryWithHeaderParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil)
req.Header.Set("language", "de")
req.Header.Set("length", "99")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPathParams(PathParams{{
Name: "id",
Value: "999",
}})
type SearchOpts struct {
ID int `param:"id"`
Length int `query:"length"`
Page int `query:"page"`
Search string `query:"search"`
Language string `query:"language" header:"language"`
}
opts := SearchOpts{
Length: 100,
Page: 0,
Search: "default value",
Language: "en",
}
err := c.Bind(&opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // bind from query
assert.Equal(t, 10, opts.Page) // bind from query
assert.Equal(t, 999, opts.ID) // bind from path param
assert.Equal(t, "et", opts.Language) // bind from query
assert.Equal(t, "default value", opts.Search) // default value stays
// make sure another bind will not mess already set values unless there are new values
err = BindHeaders(c, &opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists
assert.Equal(t, 10, opts.Page)
assert.Equal(t, 999, opts.ID)
assert.Equal(t, "de", opts.Language) // header overwrites now this value
assert.Equal(t, "default value", opts.Search)
}
func TestBindUnmarshalParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) {
func TestBindUnmarshalText(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
func TestBindUnmarshalTextPtr(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) {
func TestBindbindData(t *testing.T) {
a := assert.New(t)
ts := new(bindTestStruct)
b := new(DefaultBinder)
err := b.bindData(ts, values, "form")
err := bindData(ts, values, "form")
a.NoError(err)
a.Equal(0, ts.I)
@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) {
func TestBindParam(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/", nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPath("/users/:id/:name")
c.SetParamNames("id", "name")
c.SetParamValues("1", "Jon Snow")
cc := c.(RoutableContext)
cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"})
cc.SetRawPathParams(&PathParams{
{Name: "id", Value: "1"},
{Name: "name", Value: "Jon Snow"},
})
u := new(user)
err := c.Bind(u)
@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param
c2 := e.NewContext(req, rec)
c2.SetPath("/users/:id")
c2.SetParamNames("id")
c2.SetParamValues("1")
cc2 := c2.(RoutableContext)
cc2.SetRouteInfo(routeInfo{path: "/users/:id"})
cc2.SetRawPathParams(&PathParams{
{Name: "id", Value: "1"},
})
u = new(user)
err = c2.Bind(u)
@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) {
// Bind something with param and post data payload
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
e2 := New()
req2 := httptest.NewRequest(POST, "/", body)
req2 := httptest.NewRequest(http.MethodPost, "/", body)
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2)
c3.SetPath("/users/:id")
c3.SetParamNames("id")
c3.SetParamValues("1")
cc3 := c3.(RoutableContext)
cc3.SetRouteInfo(routeInfo{path: "/users/:id"})
cc3.SetRawPathParams(&PathParams{
{Name: "id", Value: "1"},
})
u = new(user)
err = c3.Bind(u)
@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) {
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
}
func TestBindSetFields(t *testing.T) {
assert := assert.New(t)
func TestSetIntField(t *testing.T) {
ts := new(bindTestStruct)
ts.I = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 100, ts.I)
// second set with value sets the value
err = setIntField("5", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
}
func TestSetUintField(t *testing.T) {
ts := new(bindTestStruct)
ts.UI = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(100), ts.UI)
// second set with value sets the value
err = setUintField("5", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
}
func TestSetFloatField(t *testing.T) {
ts := new(bindTestStruct)
ts.F32 = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(100), ts.F32)
// second set with value sets the value
err = setFloatField("15.5", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
}
func TestSetBoolField(t *testing.T) {
ts := new(bindTestStruct)
ts.B = true
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// second set with value sets the value
err = setBoolField("true", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// fourth set to false
err = setBoolField("false", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, false, ts.B)
}
func TestUnmarshalFieldNonPtr(t *testing.T) {
ts := new(bindTestStruct)
val := reflect.ValueOf(ts).Elem()
// Int
if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
assert.Equal(5, ts.I)
}
if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
assert.Equal(0, ts.I)
}
// Uint
if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
assert.Equal(uint(10), ts.UI)
}
if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
assert.Equal(uint(0), ts.UI)
}
// Float
if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
assert.Equal(float32(15.5), ts.F32)
}
if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
assert.Equal(float32(0.0), ts.F32)
}
// Bool
if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
assert.Equal(true, ts.B)
}
if assert.NoError(setBoolField("", val.FieldByName("B"))) {
assert.Equal(false, ts.B)
}
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
if assert.NoError(err) {
assert.Equal(ok, true)
assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
if assert.NoError(t, err) {
assert.True(t, ok)
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
}
}
@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs()
assert := assert.New(b)
ts := new(bindTestStructWithTags)
binder := new(DefaultBinder)
var err error
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = binder.bindData(ts, values, "form")
err = bindData(ts, values, "form")
}
assert.NoError(err)
assertBindTestStruct(assert, (*bindTestStruct)(ts))
@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
c := e.NewContext(req, rec)
if !tc.whenNoPathParams {
c.SetParamNames("node")
c.SetParamValues("node_from_path")
cc := c.(RoutableContext)
cc.SetRawPathParams(&PathParams{
{Name: "node", Value: "node_from_path"},
})
}
var bindTarget interface{}
@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
}
b := new(DefaultBinder)
err := b.Bind(bindTarget, c)
err := b.Bind(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) {
c := e.NewContext(req, rec)
if !tc.whenNoPathParams {
c.SetParamNames("node")
c.SetParamValues("real_node")
cc := c.(RoutableContext)
cc.SetRawPathParams(&PathParams{
{Name: "node", Value: "real_node"},
})
}
var bindTarget interface{}
@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) {
} else {
bindTarget = &Node{}
}
b := new(DefaultBinder)
err := b.BindBody(c, bindTarget)
err := BindBody(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {

View File

@ -123,10 +123,10 @@ func QueryParamsBinder(c Context) *ValueBinder {
func PathParamsBinder(c Context) *ValueBinder {
return &ValueBinder{
failFast: true,
ValueFunc: c.Param,
ValueFunc: c.PathParam,
ValuesFunc: func(sourceParam string) []string {
// path parameter should not have multiple values so getting values does not make sense but lets not error out here
value := c.Param(sourceParam)
value := c.PathParam(sourceParam)
if value == "" {
return nil
}

View File

@ -4,7 +4,7 @@ package echo_test
import (
"encoding/base64"
"fmt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"log"
"net/http"
"net/http/httptest"

View File

@ -1,265 +0,0 @@
// +build go1.15
package echo
/**
Since version 1.15 time.Time and time.Duration error message pattern has changed (values are wrapped now in \"\")
So pre 1.15 these tests fail with similar error:
expected: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param"
actual : "code=400, message=failed to bind field value to Duration, internal=time: invalid duration nope, field=param"
*/
import (
"errors"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func createTestContext15(URL string, body io.Reader, pathParams map[string]string) Context {
e := New()
req := httptest.NewRequest(http.MethodGet, URL, body)
if body != nil {
req.Header.Set(HeaderContentType, MIMEApplicationJSON)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if len(pathParams) > 0 {
names := make([]string, 0)
values := make([]string, 0)
for name, value := range pathParams {
names = append(names, name)
values = append(values, value)
}
c.SetParamNames(names...)
c.SetParamValues(values...)
}
return c
}
func TestValueBinder_TimeError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
whenLayout string
expectValue time.Time
expectError string
}{
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: time.Time{},
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: time.Time{},
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext15(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
dest := time.Time{}
var err error
if tc.whenMust {
err = b.MustTime("param", &dest, tc.whenLayout).BindError()
} else {
err = b.Time("param", &dest, tc.whenLayout).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_TimesError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
whenLayout string
expectValue []time.Time
expectError string
}{
{
name: "nok, fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: []time.Time(nil),
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: []time.Time(nil),
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: []time.Time(nil),
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext15(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
b.errors = tc.givenBindErrors
layout := time.RFC3339
if tc.whenLayout != "" {
layout = tc.whenLayout
}
var dest []time.Time
var err error
if tc.whenMust {
err = b.MustTimes("param", &dest, layout).BindError()
} else {
err = b.Times("param", &dest, layout).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_DurationError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue time.Duration
expectError string
}{
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: 0,
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: 0,
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext15(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
var dest time.Duration
var err error
if tc.whenMust {
err = b.MustDuration("param", &dest).BindError()
} else {
err = b.Duration("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_DurationsError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue []time.Duration
expectError string
}{
{
name: "nok, fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: []time.Duration(nil),
expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: []time.Duration(nil),
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: []time.Duration(nil),
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext15(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
b.errors = tc.givenBindErrors
var dest []time.Duration
var err error
if tc.whenMust {
err = b.MustDurations("param", &dest).BindError()
} else {
err = b.Durations("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -26,14 +26,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string)
c := e.NewContext(req, rec)
if len(pathParams) > 0 {
names := make([]string, 0)
values := make([]string, 0)
params := make(PathParams, 0)
for name, value := range pathParams {
names = append(names, name)
values = append(values, value)
params = append(params, PathParam{
Name: name,
Value: value,
})
}
c.SetParamNames(names...)
c.SetParamValues(values...)
cc := c.(RoutableContext)
cc.SetRawPathParams(&params)
}
return c
@ -2917,7 +2918,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
_ = binder.Bind(&dest, c)
_ = binder.Bind(c, &dest)
}
}
@ -2984,7 +2985,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
_ = binder.Bind(&dest, c)
_ = binder.Bind(c, &dest)
if dest.Int64 != 1 {
b.Fatalf("int64!=1")
}
@ -3029,3 +3030,224 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) {
}
}
}
func TestValueBinder_TimeError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
whenLayout string
expectValue time.Time
expectError string
}{
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: time.Time{},
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: time.Time{},
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
dest := time.Time{}
var err error
if tc.whenMust {
err = b.MustTime("param", &dest, tc.whenLayout).BindError()
} else {
err = b.Time("param", &dest, tc.whenLayout).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_TimesError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
whenLayout string
expectValue []time.Time
expectError string
}{
{
name: "nok, fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: []time.Time(nil),
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: []time.Time(nil),
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: []time.Time(nil),
expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
b.errors = tc.givenBindErrors
layout := time.RFC3339
if tc.whenLayout != "" {
layout = tc.whenLayout
}
var dest []time.Time
var err error
if tc.whenMust {
err = b.MustTimes("param", &dest, layout).BindError()
} else {
err = b.Times("param", &dest, layout).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_DurationError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue time.Duration
expectError string
}{
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: 0,
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: 0,
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
var dest time.Duration
var err error
if tc.whenMust {
err = b.MustDuration("param", &dest).BindError()
} else {
err = b.Duration("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_DurationsError(t *testing.T) {
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue []time.Duration
expectError string
}{
{
name: "nok, fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: []time.Duration(nil),
expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: []time.Duration(nil),
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: []time.Duration(nil),
expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
b.errors = tc.givenBindErrors
var dest []time.Duration
var err error
if tc.whenMust {
err = b.MustDurations("param", &dest).BindError()
} else {
err = b.Durations("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -3,20 +3,22 @@ package echo
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io"
"io/fs"
"mime/multipart"
"net"
"net/http"
"net/url"
"path/filepath"
"strings"
"sync"
)
type (
// Context represents the context of the current HTTP request. It holds request and
// response objects, path, path parameters, data and registered handler.
Context interface {
type Context interface {
// Request returns `*http.Request`.
Request() *http.Request
@ -43,30 +45,28 @@ type (
// The behavior can be configured using `Echo#IPExtractor`.
RealIP() string
// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
RouteInfo() RouteInfo
// Path returns the registered path for the handler.
Path() string
// SetPath sets the registered path for the handler.
SetPath(p string)
// PathParam returns path parameter by name.
PathParam(name string) string
// Param returns path parameter by name.
Param(name string) string
// PathParams returns path parameter values.
PathParams() PathParams
// ParamNames returns path parameter names.
ParamNames() []string
// SetParamNames sets path parameter names.
SetParamNames(names ...string)
// ParamValues returns path parameter values.
ParamValues() []string
// SetParamValues sets path parameter values.
SetParamValues(values ...string)
// SetPathParams sets path parameters for current request.
SetPathParams(params PathParams)
// QueryParam returns the query param for the provided name.
QueryParam(name string) string
// QueryParamDefault returns the query param or default value for the provided name.
QueryParamDefault(name, defaultValue string) string
// QueryParams returns the query parameters as `url.Values`.
QueryParams() url.Values
@ -76,8 +76,11 @@ type (
// FormValue returns the form field value for the provided name.
FormValue(name string) string
// FormParams returns the form parameters as `url.Values`.
FormParams() (url.Values, error)
// FormValueDefault returns the form field value or default value for the provided name.
FormValueDefault(name, defaultValue string) string
// FormValues returns the form field values as `url.Values`.
FormValues() (url.Values, error)
// FormFile returns the multipart form file for the provided name.
FormFile(name string) (*multipart.FileHeader, error)
@ -156,6 +159,9 @@ type (
// File sends a response with the content of the file.
File(file string) error
// FileFS sends a response with the content of the file from given filesystem.
FileFS(file string, filesystem fs.FS) error
// Attachment sends a response as attachment, prompting client to save the
// file.
Attachment(file string, name string) error
@ -169,23 +175,15 @@ type (
// Redirect redirects the request to a provided URL with status code.
Redirect(code int, url string) error
// Error invokes the registered HTTP error handler. Generally used by middleware.
Error(err error)
// Handler returns the matched handler by router.
Handler() HandlerFunc
// SetHandler sets the matched handler by router.
SetHandler(h HandlerFunc)
// Logger returns the `Logger` instance.
Logger() Logger
// Set the logger
SetLogger(l Logger)
// Echo returns the `Echo` instance.
Echo() *Echo
}
// ServableContext is interface that Echo context implementation must implement to be usable in middleware/handlers and
// be able to be routed by Router.
type ServableContext interface {
Context // minimal set of methods for middlewares and handler
RoutableContext // minimal set for routing. These methods should not be accessed in middlewares/handlers
// Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
@ -193,21 +191,6 @@ type (
Reset(r *http.Request, w http.ResponseWriter)
}
context struct {
request *http.Request
response *Response
path string
pnames []string
pvalues []string
query url.Values
handler HandlerFunc
store Map
echo *Echo
logger Logger
lock sync.RWMutex
}
)
const (
// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
// Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
@ -221,39 +204,95 @@ const (
defaultIndent = " "
)
func (c *context) writeContentType(value string) {
// DefaultContext is default implementation of Context interface and can be embedded into structs to compose
// new Contexts with extended/modified behaviour.
type DefaultContext struct {
request *http.Request
response *Response
route RouteInfo
path string
// pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations.
pathParams *PathParams
// currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request.
// Lifecycle is not handle by Echo and could have excess allocations per served Request
currentParams PathParams
query url.Values
store Map
echo *Echo
lock sync.RWMutex
}
// NewDefaultContext creates new instance of DefaultContext.
// Argument pathParamAllocSize must be value that is stored in Echo.contextPathParamAllocSize field and is used
// to preallocate PathParams slice.
func NewDefaultContext(e *Echo, pathParamAllocSize int) *DefaultContext {
p := make(PathParams, pathParamAllocSize)
return &DefaultContext{
pathParams: &p,
store: make(Map),
echo: e,
}
}
// Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
// See `Echo#ServeHTTP()`
func (c *DefaultContext) Reset(r *http.Request, w http.ResponseWriter) {
c.request = r
c.response.reset(w)
c.query = nil
c.store = nil
c.route = nil
c.path = ""
// NOTE: Don't reset because it has to have length of c.echo.contextPathParamAllocSize at all times
*c.pathParams = (*c.pathParams)[:0]
c.currentParams = nil
}
func (c *DefaultContext) writeContentType(value string) {
header := c.Response().Header()
if header.Get(HeaderContentType) == "" {
header.Set(HeaderContentType, value)
}
}
func (c *context) Request() *http.Request {
// Request returns `*http.Request`.
func (c *DefaultContext) Request() *http.Request {
return c.request
}
func (c *context) SetRequest(r *http.Request) {
// SetRequest sets `*http.Request`.
func (c *DefaultContext) SetRequest(r *http.Request) {
c.request = r
}
func (c *context) Response() *Response {
// Response returns `*Response`.
func (c *DefaultContext) Response() *Response {
return c.response
}
func (c *context) SetResponse(r *Response) {
// SetResponse sets `*Response`.
func (c *DefaultContext) SetResponse(r *Response) {
c.response = r
}
func (c *context) IsTLS() bool {
// IsTLS returns true if HTTP connection is TLS otherwise false.
func (c *DefaultContext) IsTLS() bool {
return c.request.TLS != nil
}
func (c *context) IsWebSocket() bool {
// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
func (c *DefaultContext) IsWebSocket() bool {
upgrade := c.request.Header.Get(HeaderUpgrade)
return strings.EqualFold(upgrade, "websocket")
}
func (c *context) Scheme() string {
// Scheme returns the HTTP protocol scheme, `http` or `https`.
func (c *DefaultContext) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
if c.IsTLS() {
@ -274,7 +313,10 @@ func (c *context) Scheme() string {
return "http"
}
func (c *context) RealIP() string {
// RealIP returns the client's network address based on `X-Forwarded-For`
// or `X-Real-IP` request header.
// The behavior can be configured using `Echo#IPExtractor`.
func (c *DefaultContext) RealIP() string {
if c.echo != nil && c.echo.IPExtractor != nil {
return c.echo.IPExtractor(c.request)
}
@ -293,85 +335,116 @@ func (c *context) RealIP() string {
return ra
}
func (c *context) Path() string {
// Path returns the registered path for the handler.
func (c *DefaultContext) Path() string {
return c.path
}
func (c *context) SetPath(p string) {
// SetPath sets the registered path for the handler.
func (c *DefaultContext) SetPath(p string) {
c.path = p
}
func (c *context) Param(name string) string {
for i, n := range c.pnames {
if i < len(c.pvalues) {
if n == name {
return c.pvalues[i]
}
}
}
return ""
// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
func (c *DefaultContext) RouteInfo() RouteInfo {
return c.route
}
func (c *context) ParamNames() []string {
return c.pnames
// SetRouteInfo sets the route info of this request to the context.
func (c *DefaultContext) SetRouteInfo(ri RouteInfo) {
c.route = ri
}
func (c *context) SetParamNames(names ...string) {
c.pnames = names
l := len(names)
if *c.echo.maxParam < l {
*c.echo.maxParam = l
// RawPathParams returns raw path pathParams value. Allocation of PathParams is handled by Context.
func (c *DefaultContext) RawPathParams() *PathParams {
return c.pathParams
}
if len(c.pvalues) < l {
// Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
// probably those values will be overriden in a Context#SetParamValues
newPvalues := make([]string, l)
copy(newPvalues, c.pvalues)
c.pvalues = newPvalues
}
// SetRawPathParams replaces any existing param values with new values for this context lifetime (request).
//
// DO NOT USE!
// Do not set any other value than what you got from RawPathParams as allocation of PathParams is handled by Context.
// If you mess up size of pathParams size your application will panic/crash during routing
func (c *DefaultContext) SetRawPathParams(params *PathParams) {
c.pathParams = params
}
func (c *context) ParamValues() []string {
return c.pvalues[:len(c.pnames)]
// PathParam returns path parameter by name.
func (c *DefaultContext) PathParam(name string) string {
if c.currentParams != nil {
return c.currentParams.Get(name, "")
}
func (c *context) SetParamValues(values ...string) {
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times
// It will brake the Router#Find code
limit := len(values)
if limit > *c.echo.maxParam {
limit = *c.echo.maxParam
}
for i := 0; i < limit; i++ {
c.pvalues[i] = values[i]
}
return c.pathParams.Get(name, "")
}
func (c *context) QueryParam(name string) string {
// PathParamDefault does not exist as expecting empty path param makes no sense
// PathParams returns path parameter values.
func (c *DefaultContext) PathParams() PathParams {
if c.currentParams != nil {
return c.currentParams
}
result := make(PathParams, len(*c.pathParams))
copy(result, *c.pathParams)
return result
}
// SetPathParams sets path parameters for current request.
func (c *DefaultContext) SetPathParams(params PathParams) {
c.currentParams = params
}
// QueryParam returns the query param for the provided name.
func (c *DefaultContext) QueryParam(name string) string {
if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query.Get(name)
}
func (c *context) QueryParams() url.Values {
// QueryParamDefault returns the query param or default value for the provided name.
// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string
func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string {
value := c.QueryParam(name)
if value == "" {
value = defaultValue
}
return value
}
// QueryParams returns the query parameters as `url.Values`.
func (c *DefaultContext) QueryParams() url.Values {
if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query
}
func (c *context) QueryString() string {
// QueryString returns the URL query string.
func (c *DefaultContext) QueryString() string {
return c.request.URL.RawQuery
}
func (c *context) FormValue(name string) string {
// FormValue returns the form field value for the provided name.
func (c *DefaultContext) FormValue(name string) string {
return c.request.FormValue(name)
}
func (c *context) FormParams() (url.Values, error) {
// FormValueDefault returns the form field value or default value for the provided name.
// Note: FormValueDefault does not distinguish if form had no value by that name or value was empty string
func (c *DefaultContext) FormValueDefault(name, defaultValue string) string {
value := c.FormValue(name)
if value == "" {
value = defaultValue
}
return value
}
// FormValues returns the form field values as `url.Values`.
func (c *DefaultContext) FormValues() (url.Values, error) {
if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) {
if err := c.request.ParseMultipartForm(defaultMemory); err != nil {
return nil, err
@ -384,7 +457,8 @@ func (c *context) FormParams() (url.Values, error) {
return c.request.Form, nil
}
func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
// FormFile returns the multipart form file for the provided name.
func (c *DefaultContext) FormFile(name string) (*multipart.FileHeader, error) {
f, fh, err := c.request.FormFile(name)
if err != nil {
return nil, err
@ -393,30 +467,36 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
return fh, nil
}
func (c *context) MultipartForm() (*multipart.Form, error) {
// MultipartForm returns the multipart form.
func (c *DefaultContext) MultipartForm() (*multipart.Form, error) {
err := c.request.ParseMultipartForm(defaultMemory)
return c.request.MultipartForm, err
}
func (c *context) Cookie(name string) (*http.Cookie, error) {
// Cookie returns the named cookie provided in the request.
func (c *DefaultContext) Cookie(name string) (*http.Cookie, error) {
return c.request.Cookie(name)
}
func (c *context) SetCookie(cookie *http.Cookie) {
// SetCookie adds a `Set-Cookie` header in HTTP response.
func (c *DefaultContext) SetCookie(cookie *http.Cookie) {
http.SetCookie(c.Response(), cookie)
}
func (c *context) Cookies() []*http.Cookie {
// Cookies returns the HTTP cookies sent with the request.
func (c *DefaultContext) Cookies() []*http.Cookie {
return c.request.Cookies()
}
func (c *context) Get(key string) interface{} {
// Get retrieves data from the context.
func (c *DefaultContext) Get(key string) interface{} {
c.lock.RLock()
defer c.lock.RUnlock()
return c.store[key]
}
func (c *context) Set(key string, val interface{}) {
// Set saves data in the context.
func (c *DefaultContext) Set(key string, val interface{}) {
c.lock.Lock()
defer c.lock.Unlock()
@ -426,18 +506,24 @@ func (c *context) Set(key string, val interface{}) {
c.store[key] = val
}
func (c *context) Bind(i interface{}) error {
return c.echo.Binder.Bind(i, c)
// Bind binds the request body into provided type `i`. The default binder
// does it based on Content-Type header.
func (c *DefaultContext) Bind(i interface{}) error {
return c.echo.Binder.Bind(c, i)
}
func (c *context) Validate(i interface{}) error {
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
// Validator must be registered using `Echo#Validator`.
func (c *DefaultContext) Validate(i interface{}) error {
if c.echo.Validator == nil {
return ErrValidatorNotRegistered
}
return c.echo.Validator.Validate(i)
}
func (c *context) Render(code int, name string, data interface{}) (err error) {
// Render renders a template with data and sends a text/html response with status
// code. Renderer must be registered using `Echo.Renderer`.
func (c *DefaultContext) Render(code int, name string, data interface{}) (err error) {
if c.echo.Renderer == nil {
return ErrRendererNotRegistered
}
@ -448,19 +534,22 @@ func (c *context) Render(code int, name string, data interface{}) (err error) {
return c.HTMLBlob(code, buf.Bytes())
}
func (c *context) HTML(code int, html string) (err error) {
// HTML sends an HTTP response with status code.
func (c *DefaultContext) HTML(code int, html string) (err error) {
return c.HTMLBlob(code, []byte(html))
}
func (c *context) HTMLBlob(code int, b []byte) (err error) {
// HTMLBlob sends an HTTP blob response with status code.
func (c *DefaultContext) HTMLBlob(code int, b []byte) (err error) {
return c.Blob(code, MIMETextHTMLCharsetUTF8, b)
}
func (c *context) String(code int, s string) (err error) {
// String sends a string response with status code.
func (c *DefaultContext) String(code int, s string) (err error) {
return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s))
}
func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) {
func (c *DefaultContext) jsonPBlob(code int, callback string, i interface{}) (err error) {
indent := ""
if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
indent = defaultIndent
@ -479,13 +568,14 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error
return
}
func (c *context) json(code int, i interface{}, indent string) error {
func (c *DefaultContext) json(code int, i interface{}, indent string) error {
c.writeContentType(MIMEApplicationJSONCharsetUTF8)
c.response.Status = code
return c.echo.JSONSerializer.Serialize(c, i, indent)
}
func (c *context) JSON(code int, i interface{}) (err error) {
// JSON sends a JSON response with status code.
func (c *DefaultContext) JSON(code int, i interface{}) (err error) {
indent := ""
if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
indent = defaultIndent
@ -493,19 +583,25 @@ func (c *context) JSON(code int, i interface{}) (err error) {
return c.json(code, i, indent)
}
func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) {
// JSONPretty sends a pretty-print JSON with status code.
func (c *DefaultContext) JSONPretty(code int, i interface{}, indent string) (err error) {
return c.json(code, i, indent)
}
func (c *context) JSONBlob(code int, b []byte) (err error) {
// JSONBlob sends a JSON blob response with status code.
func (c *DefaultContext) JSONBlob(code int, b []byte) (err error) {
return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b)
}
func (c *context) JSONP(code int, callback string, i interface{}) (err error) {
// JSONP sends a JSONP response with status code. It uses `callback` to construct
// the JSONP payload.
func (c *DefaultContext) JSONP(code int, callback string, i interface{}) (err error) {
return c.jsonPBlob(code, callback, i)
}
func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) {
// JSONPBlob sends a JSONP blob response with status code. It uses `callback`
// to construct the JSONP payload.
func (c *DefaultContext) JSONPBlob(code int, callback string, b []byte) (err error) {
c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
if _, err = c.response.Write([]byte(callback + "(")); err != nil {
@ -518,7 +614,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) {
return
}
func (c *context) xml(code int, i interface{}, indent string) (err error) {
func (c *DefaultContext) xml(code int, i interface{}, indent string) (err error) {
c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
enc := xml.NewEncoder(c.response)
@ -531,7 +627,8 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) {
return enc.Encode(i)
}
func (c *context) XML(code int, i interface{}) (err error) {
// XML sends an XML response with status code.
func (c *DefaultContext) XML(code int, i interface{}) (err error) {
indent := ""
if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
indent = defaultIndent
@ -539,11 +636,13 @@ func (c *context) XML(code int, i interface{}) (err error) {
return c.xml(code, i, indent)
}
func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) {
// XMLPretty sends a pretty-print XML with status code.
func (c *DefaultContext) XMLPretty(code int, i interface{}, indent string) (err error) {
return c.xml(code, i, indent)
}
func (c *context) XMLBlob(code int, b []byte) (err error) {
// XMLBlob sends an XML blob response with status code.
func (c *DefaultContext) XMLBlob(code int, b []byte) (err error) {
c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
if _, err = c.response.Write([]byte(xml.Header)); err != nil {
@ -553,39 +652,86 @@ func (c *context) XMLBlob(code int, b []byte) (err error) {
return
}
func (c *context) Blob(code int, contentType string, b []byte) (err error) {
// Blob sends a blob response with status code and content type.
func (c *DefaultContext) Blob(code int, contentType string, b []byte) (err error) {
c.writeContentType(contentType)
c.response.WriteHeader(code)
_, err = c.response.Write(b)
return
}
func (c *context) Stream(code int, contentType string, r io.Reader) (err error) {
// Stream sends a streaming response with status code and content type.
func (c *DefaultContext) Stream(code int, contentType string, r io.Reader) (err error) {
c.writeContentType(contentType)
c.response.WriteHeader(code)
_, err = io.Copy(c.response, r)
return
}
func (c *context) Attachment(file, name string) error {
// File sends a response with the content of the file.
func (c *DefaultContext) File(file string) error {
return fsFile(c, file, c.echo.Filesystem)
}
// FileFS serves file from given file system.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (c *DefaultContext) FileFS(file string, filesystem fs.FS) error {
return fsFile(c, file, filesystem)
}
func fsFile(c Context, file string, filesystem fs.FS) error {
f, err := filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
f, err = filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return err
}
}
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
}
// Attachment sends a response as attachment, prompting client to save the file.
func (c *DefaultContext) Attachment(file, name string) error {
return c.contentDisposition(file, name, "attachment")
}
func (c *context) Inline(file, name string) error {
// Inline sends a response as inline, opening the file in the browser.
func (c *DefaultContext) Inline(file, name string) error {
return c.contentDisposition(file, name, "inline")
}
func (c *context) contentDisposition(file, name, dispositionType string) error {
func (c *DefaultContext) contentDisposition(file, name, dispositionType string) error {
c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name))
return c.File(file)
}
func (c *context) NoContent(code int) error {
// NoContent sends a response with no body and a status code.
func (c *DefaultContext) NoContent(code int) error {
c.response.WriteHeader(code)
return nil
}
func (c *context) Redirect(code int, url string) error {
// Redirect redirects the request to a provided URL with status code.
func (c *DefaultContext) Redirect(code int, url string) error {
if code < 300 || code > 308 {
return ErrInvalidRedirectCode
}
@ -594,45 +740,7 @@ func (c *context) Redirect(code int, url string) error {
return nil
}
func (c *context) Error(err error) {
c.echo.HTTPErrorHandler(err, c)
}
func (c *context) Echo() *Echo {
// Echo returns the `Echo` instance.
func (c *DefaultContext) Echo() *Echo {
return c.echo
}
func (c *context) Handler() HandlerFunc {
return c.handler
}
func (c *context) SetHandler(h HandlerFunc) {
c.handler = h
}
func (c *context) Logger() Logger {
res := c.logger
if res != nil {
return res
}
return c.echo.Logger
}
func (c *context) SetLogger(l Logger) {
c.logger = l
}
func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
c.request = r
c.response.reset(w)
c.query = nil
c.handler = NotFoundHandler
c.store = nil
c.path = ""
c.pnames = nil
c.logger = nil
// NOTE: Don't reset because it has to have length c.echo.maxParam at all times
for i := 0; i < *c.echo.maxParam; i++ {
c.pvalues[i] = ""
}
}

View File

@ -1,33 +0,0 @@
//go:build !go1.16
// +build !go1.16
package echo
import (
"net/http"
"os"
"path/filepath"
)
func (c *context) File(file string) (err error) {
f, err := os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return
}
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return
}

View File

@ -1,52 +0,0 @@
//go:build go1.16
// +build go1.16
package echo
import (
"errors"
"io"
"io/fs"
"net/http"
"path/filepath"
)
func (c *context) File(file string) error {
return fsFile(c, file, c.echo.Filesystem)
}
// FileFS serves file from given file system.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (c *context) FileFS(file string, filesystem fs.FS) error {
return fsFile(c, file, filesystem)
}
func fsFile(c Context, file string, filesystem fs.FS) error {
f, err := filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
f, err = filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return err
}
}
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
}

View File

@ -1,135 +0,0 @@
//go:build go1.16
// +build go1.16
package echo
import (
"github.com/stretchr/testify/assert"
"io/fs"
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestContext_File(t *testing.T) {
var testCases = []struct {
name string
whenFile string
whenFS fs.FS
expectStatus int
expectStartsWith []byte
expectError string
}{
{
name: "ok, from default file system",
whenFile: "_fixture/images/walle.png",
whenFS: nil,
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "ok, from custom file system",
whenFile: "walle.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, not existent file",
whenFile: "not.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: nil,
expectError: "code=404, message=Not Found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
if tc.whenFS != nil {
e.Filesystem = tc.whenFS
}
handler := func(ec Context) error {
return ec.(*context).File(tc.whenFile)
}
req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.Equal(t, tc.expectStatus, rec.Code)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}
func TestContext_FileFS(t *testing.T) {
var testCases = []struct {
name string
whenFile string
whenFS fs.FS
expectStatus int
expectStartsWith []byte
expectError string
}{
{
name: "ok",
whenFile: "walle.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, not existent file",
whenFile: "not.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: nil,
expectError: "code=404, message=Not Found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
handler := func(ec Context) error {
return ec.(*context).FileFS(tc.whenFile, tc.whenFS)
}
req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := handler(c)
assert.Equal(t, tc.expectStatus, rec.Code)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}

File diff suppressed because it is too large Load Diff

1004
echo.go

File diff suppressed because it is too large Load Diff

View File

@ -1,62 +0,0 @@
//go:build !go1.16
// +build !go1.16
package echo
import (
"net/http"
"net/url"
"os"
"path/filepath"
)
type filesystem struct {
}
func createFilesystem() filesystem {
return filesystem{}
}
// Static registers a new route with path prefix to serve static files from the
// provided root directory.
func (e *Echo) Static(prefix, root string) *Route {
if root == "" {
root = "." // For security we want to restrict to CWD.
}
return e.static(prefix, root, e.GET)
}
func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route {
h := func(c Context) error {
p, err := url.PathUnescape(c.Param("*"))
if err != nil {
return err
}
name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
fi, err := os.Stat(name)
if err != nil {
// The access path does not exist
return NotFoundHandler(c)
}
// If the request is for a directory and does not end with "/"
p = c.Request().URL.Path // path must not be empty.
if fi.IsDir() && p[len(p)-1] != '/' {
// Redirect to ends with "/"
return c.Redirect(http.StatusMovedPermanently, p+"/")
}
return c.File(name)
}
// Handle added routes based on trailing slash:
// /prefix => exact route "/prefix" + any route "/prefix/*"
// /prefix/ => only any route "/prefix/*"
if prefix != "" {
if prefix[len(prefix)-1] == '/' {
// Only add any route for intentional trailing slash
return get(prefix+"*", h)
}
get(prefix, h)
}
return get(prefix+"/*", h)
}

View File

@ -1,169 +0,0 @@
//go:build go1.16
// +build go1.16
package echo
import (
"fmt"
"io/fs"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
)
type filesystem struct {
// Filesystem is file system used by Static and File handlers to access files.
// Defaults to os.DirFS(".")
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
Filesystem fs.FS
}
func createFilesystem() filesystem {
return filesystem{
Filesystem: newDefaultFS(),
}
}
// Static registers a new route with path prefix to serve static files from the provided root directory.
func (e *Echo) Static(pathPrefix, fsRoot string) *Route {
subFs := MustSubFS(e.Filesystem, fsRoot)
return e.Add(
http.MethodGet,
pathPrefix+"*",
StaticDirectoryHandler(subFs, false),
)
}
// StaticFS registers a new route with path prefix to serve static files from the provided file system.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route {
return e.Add(
http.MethodGet,
pathPrefix+"*",
StaticDirectoryHandler(filesystem, false),
)
}
// StaticDirectoryHandler creates handler function to serve files from provided file system
// When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
return func(c Context) error {
p := c.Param("*")
if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
tmpPath, err := url.PathUnescape(p)
if err != nil {
return fmt.Errorf("failed to unescape path variable: %w", err)
}
p = tmpPath
}
// fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
fi, err := fs.Stat(fileSystem, name)
if err != nil {
return ErrNotFound
}
// If the request is for a directory and does not end with "/"
p = c.Request().URL.Path // path must not be empty.
if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
// Redirect to ends with "/"
return c.Redirect(http.StatusMovedPermanently, p+"/")
}
return fsFile(c, name, fileSystem)
}
}
// FileFS registers a new route with path to serve file from the provided file system.
func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route {
return e.GET(path, StaticFileHandler(file, filesystem), m...)
}
// StaticFileHandler creates handler function to serve file from provided file system
func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
return func(c Context) error {
return fsFile(c, file, filesystem)
}
}
// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`.
// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface.
// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/`
// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not
// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to
// traverse up from current executable run path.
// NB: private because you really should use fs.FS implementation instances
type defaultFS struct {
prefix string
fs fs.FS
}
func newDefaultFS() *defaultFS {
dir, _ := os.Getwd()
return &defaultFS{
prefix: dir,
fs: nil,
}
}
func (fs defaultFS) Open(name string) (fs.File, error) {
if fs.fs == nil {
return os.Open(name)
}
return fs.fs.Open(name)
}
func subFS(currentFs fs.FS, root string) (fs.FS, error) {
root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
if dFS, ok := currentFs.(*defaultFS); ok {
// we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
// fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
// were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
if isRelativePath(root) {
root = filepath.Join(dFS.prefix, root)
}
return &defaultFS{
prefix: root,
fs: os.DirFS(root),
}, nil
}
return fs.Sub(currentFs, root)
}
func isRelativePath(path string) bool {
if path == "" {
return true
}
if path[0] == '/' {
return false
}
if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 {
// https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names
// https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats
return false
}
return true
}
// MustSubFS creates sub FS from current filesystem or panic on failure.
// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
//
// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
// create sub fs which uses necessary prefix for directory path.
func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
subFs, err := subFS(currentFs, fsRoot)
if err != nil {
panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
}
return subFs
}

View File

@ -1,265 +0,0 @@
//go:build go1.16
// +build go1.16
package echo
import (
"github.com/stretchr/testify/assert"
"io/fs"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
func TestEcho_StaticFS(t *testing.T) {
var testCases = []struct {
name string
givenPrefix string
givenFs fs.FS
givenFsRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenFs: os.DirFS("./_fixture/images"),
whenURL: "/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "ok, from sub fs",
givenPrefix: "/images",
givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"),
whenURL: "/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "No file",
givenPrefix: "/images",
givenFs: os.DirFS("_fixture/scripts"),
whenURL: "/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenFs: os.DirFS("_fixture/images"),
whenURL: "/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenFs: os.DirFS("_fixture/"),
whenURL: "/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
givenFs: os.DirFS("_fixture"),
whenURL: "/static",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/static/",
expectBodyStartsWith: "",
},
{
name: "Prefixed directory 404 (request URL without slash)",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenFs: os.DirFS("_fixture"),
whenURL: "/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Prefixed directory redirect (without slash redirect to slash)",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenFs: os.DirFS("_fixture"),
whenURL: "/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenFs: os.DirFS("_fixture"),
whenURL: "/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenFs: os.DirFS("_fixture"),
whenURL: "/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenFs: os.DirFS("_fixture"),
whenURL: "/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenFs: os.DirFS("_fixture"),
whenURL: "/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenFs: os.DirFS("_fixture/"),
whenURL: `/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenFs: os.DirFS("_fixture/"),
whenURL: `/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
tmpFs := tc.givenFs
if tc.givenFsRoot != "" {
tmpFs = MustSubFS(tmpFs, tc.givenFsRoot)
}
e.StaticFS(tc.givenPrefix, tmpFs)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
} else {
assert.Equal(t, "", body)
}
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
}
func TestEcho_FileFS(t *testing.T) {
var testCases = []struct {
name string
whenPath string
whenFile string
whenFS fs.FS
givenURL string
expectCode int
expectStartsWith []byte
}{
{
name: "ok",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "walle.png",
givenURL: "/walle",
expectCode: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, requesting invalid path",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "walle.png",
givenURL: "/walle.png",
expectCode: http.StatusNotFound,
expectStartsWith: []byte(`{"message":"Not Found"}`),
},
{
name: "nok, serving not existent file from filesystem",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "not-existent.png",
givenURL: "/walle",
expectCode: http.StatusNotFound,
expectStartsWith: []byte(`{"message":"Not Found"}`),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}
func TestEcho_StaticPanic(t *testing.T) {
var testCases = []struct {
name string
givenRoot string
expectError string
}{
{
name: "panics for ../",
givenRoot: "../assets",
expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name",
},
{
name: "panics for /",
givenRoot: "/assets",
expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.Filesystem = os.DirFS("./")
assert.PanicsWithError(t, tc.expectError, func() {
e.Static("../assets", tc.givenRoot)
})
})
}
}

File diff suppressed because it is too large Load Diff

15
go.mod
View File

@ -1,24 +1,19 @@
module github.com/labstack/echo/v4
module github.com/labstack/echo/v5
go 1.17
require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/labstack/gommon v0.3.1
github.com/golang-jwt/jwt/v4 v4.2.0
github.com/stretchr/testify v1.7.0
github.com/valyala/fasttemplate v1.2.1
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/mattn/go-colorable v0.1.11 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/text v0.3.3 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

34
go.sum
View File

@ -1,14 +1,10 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o=
github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM=
github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs=
github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU=
github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -18,25 +14,15 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M=
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

164
group.go
View File

@ -1,98 +1,121 @@
package echo
import (
"io/fs"
"net/http"
)
type (
// Group is a set of sub-routes for a specified route. It can be used for inner
// routes that share a common middleware or functionality that should be separate
// from the parent echo instance while still inheriting from it.
Group struct {
common
type Group struct {
host string
prefix string
middleware []MiddlewareFunc
echo *Echo
}
)
// Use implements `Echo#Use()` for sub-routes within the Group.
// Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...)
if len(g.middleware) == 0 {
return
}
// Allow all requests to reach the group as they might get dropped if router
// doesn't find a match, making none of the group middleware process.
g.Any("", NotFoundHandler)
g.Any("/*", NotFoundHandler)
}
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodConnect, path, h, m...)
}
// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodDelete, path, h, m...)
}
// GET implements `Echo#GET()` for sub-routes within the Group.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodGet, path, h, m...)
}
// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodHead, path, h, m...)
}
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodOptions, path, h, m...)
}
// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPatch, path, h, m...)
}
// POST implements `Echo#POST()` for sub-routes within the Group.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPost, path, h, m...)
}
// PUT implements `Echo#PUT()` for sub-routes within the Group.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPut, path, h, m...)
}
// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodTrace, path, h, m...)
}
// Any implements `Echo#Any()` for sub-routes within the Group.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
routes := make([]*Route, len(methods))
for i, m := range methods {
routes[i] = g.Add(m, path, handler, middleware...)
// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
errs := make([]error, 0)
ris := make(Routes, 0)
for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
}
return routes
ris = append(ris, ri)
}
if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
}
// Match implements `Echo#Match()` for sub-routes within the Group.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
routes := make([]*Route, len(methods))
for i, m := range methods {
routes[i] = g.Add(m, path, handler, middleware...)
// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
errs := make([]error, 0)
ris := make(Routes, 0)
for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
}
return routes
ris = append(ris, ri)
}
if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
}
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
// Important! Group middlewares are only executed in case there was exact route match and not
// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
@ -102,18 +125,57 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
return
}
// File implements `Echo#File()` for sub-routes within the Group.
func (g *Group) File(path, file string) {
g.file(path, file, g.GET)
// Static implements `Echo#Static()` for sub-routes within the Group.
func (g *Group) Static(pathPrefix, fsRoot string) RouteInfo {
subFs := MustSubFS(g.echo.Filesystem, fsRoot)
return g.StaticFS(pathPrefix, subFs)
}
// Add implements `Echo#Add()` for sub-routes within the Group.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
// Combine into a new slice to avoid accidentally passing the same slice for
// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) RouteInfo {
return g.Add(
http.MethodGet,
pathPrefix+"*",
StaticDirectoryHandler(filesystem, false),
)
}
// FileFS implements `Echo#FileFS()` for sub-routes within the Group.
func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
return g.GET(path, StaticFileHandler(file, filesystem), m...)
}
// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
handler := func(c Context) error {
return c.File(file)
}
return g.Add(http.MethodGet, path, handler, middleware...)
}
// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
ri, err := g.AddRoute(Route{
Method: method,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ri
}
// AddRoute registers a new Routable with Router
func (g *Group) AddRoute(route Routable) (RouteInfo, error) {
// Combine middleware into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
return g.echo.add(g.host, method, g.prefix+path, handler, m...)
groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
return g.echo.add(g.host, groupRoute)
}

View File

@ -1,9 +0,0 @@
//go:build !go1.16
// +build !go1.16
package echo
// Static implements `Echo#Static()` for sub-routes within the Group.
func (g *Group) Static(prefix, root string) {
g.static(prefix, root, g.GET)
}

View File

@ -1,33 +0,0 @@
//go:build go1.16
// +build go1.16
package echo
import (
"io/fs"
"net/http"
)
// Static implements `Echo#Static()` for sub-routes within the Group.
func (g *Group) Static(pathPrefix, fsRoot string) {
subFs := MustSubFS(g.echo.Filesystem, fsRoot)
g.StaticFS(pathPrefix, subFs)
}
// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) {
g.Add(
http.MethodGet,
pathPrefix+"*",
StaticDirectoryHandler(filesystem, false),
)
}
// FileFS implements `Echo#FileFS()` for sub-routes within the Group.
func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route {
return g.GET(path, StaticFileHandler(file, filesystem), m...)
}

View File

@ -1,106 +0,0 @@
//go:build go1.16
// +build go1.16
package echo
import (
"github.com/stretchr/testify/assert"
"io/fs"
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestGroup_FileFS(t *testing.T) {
var testCases = []struct {
name string
whenPath string
whenFile string
whenFS fs.FS
givenURL string
expectCode int
expectStartsWith []byte
}{
{
name: "ok",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "walle.png",
givenURL: "/assets/walle",
expectCode: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, requesting invalid path",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "walle.png",
givenURL: "/assets/walle.png",
expectCode: http.StatusNotFound,
expectStartsWith: []byte(`{"message":"Not Found"}`),
},
{
name: "nok, serving not existent file from filesystem",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "not-existent.png",
givenURL: "/assets/walle",
expectCode: http.StatusNotFound,
expectStartsWith: []byte(`{"message":"Not Found"}`),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/assets")
g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}
func TestGroup_StaticPanic(t *testing.T) {
var testCases = []struct {
name string
givenRoot string
expectError string
}{
{
name: "panics for ../",
givenRoot: "../images",
expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name",
},
{
name: "panics for /",
givenRoot: "/images",
expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.Filesystem = os.DirFS("./")
g := e.Group("/assets")
assert.PanicsWithError(t, tc.expectError, func() {
g.Static("/images", tc.givenRoot)
})
})
}
}

View File

@ -1,31 +1,70 @@
package echo
import (
"github.com/stretchr/testify/assert"
"io/fs"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// TODO: Fix me
func TestGroup(t *testing.T) {
g := New().Group("/group")
h := func(Context) error { return nil }
g.CONNECT("/", h)
g.DELETE("/", h)
g.GET("/", h)
g.HEAD("/", h)
g.OPTIONS("/", h)
g.PATCH("/", h)
g.POST("/", h)
g.PUT("/", h)
g.TRACE("/", h)
g.Any("/", h)
g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
g.Static("/static", "/tmp")
g.File("/walle", "_fixture/images//walle.png")
func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
e := New()
called := false
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
called = true
return c.NoContent(http.StatusTeapot)
}
}
// even though group has middleware it will not be executed when there are no routes under that group
_ = e.Group("/group", mw)
status, body := request(http.MethodGet, "/group/nope", e)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
assert.False(t, called)
}
func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
e := New()
called := false
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
called = true
return c.NoContent(http.StatusTeapot)
}
}
// even though group has middleware and routes when we have no match on some route the middlewares for that
// group will not be executed
g := e.Group("/group", mw)
g.GET("/yes", handlerFunc)
status, body := request(http.MethodGet, "/group/nope", e)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
assert.False(t, called)
}
func TestGroup_multiLevelGroup(t *testing.T) {
e := New()
api := e.Group("/api")
users := api.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
status, body := request(http.MethodGet, "/api/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroupFile(t *testing.T) {
@ -92,11 +131,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
}
m2 := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
return c.String(http.StatusOK, c.Path())
return c.String(http.StatusOK, c.RouteInfo().Path())
}
}
h := func(c Context) error {
return c.String(http.StatusOK, c.Path())
return c.String(http.StatusOK, c.RouteInfo().Path())
}
g.Use(m1)
g.GET("/help", h, m2)
@ -119,3 +158,535 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m)
}
func TestGroup_CONNECT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.CONNECT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodConnect, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodConnect, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_DELETE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.DELETE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodDelete, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodDelete, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_HEAD(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.HEAD("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodHead, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodHead+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodHead, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_OPTIONS(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.OPTIONS("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodOptions, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodOptions, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PATCH(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PATCH("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPatch, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPatch, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_POST(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.POST("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPost, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPost+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPost, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PUT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PUT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPut, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPut+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPut, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_TRACE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.TRACE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodTrace, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodTrace, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_Any(t *testing.T) {
e := New()
users := e.Group("/users")
ris := users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 11)
for _, m := range methods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_AnyWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range methods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Match(t *testing.T) {
e := New()
myMethods := []string{http.MethodGet, http.MethodPost}
users := e.Group("/users")
ris := users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 2)
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_MatchWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
myMethods := []string{http.MethodGet, http.MethodPost}
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Static(t *testing.T) {
e := New()
g := e.Group("/books")
ri := g.Static("/download", "_fixture")
assert.Equal(t, http.MethodGet, ri.Method())
assert.Equal(t, "/books/download*", ri.Path())
assert.Equal(t, "GET:/books/download*", ri.Name())
assert.Equal(t, []string{"*"}, ri.Params())
req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
body := rec.Body.String()
assert.True(t, strings.HasPrefix(body, "<!doctype html>"))
}
func TestGroup_StaticMultiTest(t *testing.T) {
var testCases = []struct {
name string
givenPrefix string
givenRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "ok, without prefix",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "nok, without prefix does not serve dir index",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/test/", // `/test` + `*` creates route `/test*`
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "No file",
givenPrefix: "/images",
givenRoot: "_fixture/scripts",
whenURL: "/test/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
givenRoot: "_fixture",
whenURL: "/test/static",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/static/",
expectBodyStartsWith: "",
},
{
name: "Prefixed directory 404 (request URL without slash)",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Prefixed directory redirect (without slash redirect to slash)",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/test")
g.Static(tc.givenPrefix, tc.givenRoot)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
} else {
assert.Equal(t, "", body)
}
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
}
func TestGroup_FileFS(t *testing.T) {
var testCases = []struct {
name string
whenPath string
whenFile string
whenFS fs.FS
givenURL string
expectCode int
expectStartsWith []byte
}{
{
name: "ok",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "walle.png",
givenURL: "/assets/walle",
expectCode: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, requesting invalid path",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "walle.png",
givenURL: "/assets/walle.png",
expectCode: http.StatusNotFound,
expectStartsWith: []byte(`{"message":"Not Found"}`),
},
{
name: "nok, serving not existent file from filesystem",
whenPath: "/walle",
whenFS: os.DirFS("_fixture/images"),
whenFile: "not-existent.png",
givenURL: "/assets/walle",
expectCode: http.StatusNotFound,
expectStartsWith: []byte(`{"message":"Not Found"}`),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/assets")
g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}
func TestGroup_StaticPanic(t *testing.T) {
var testCases = []struct {
name string
givenRoot string
expectError string
}{
{
name: "panics for ../",
givenRoot: "../images",
expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name",
},
{
name: "panics for /",
givenRoot: "/images",
expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.Filesystem = os.DirFS("./")
g := e.Group("/assets")
assert.PanicsWithError(t, tc.expectError, func() {
g.Static("/images", tc.givenRoot)
})
})
}
}

74
httperror.go Normal file
View File

@ -0,0 +1,74 @@
package echo
import (
"errors"
"fmt"
"net/http"
)
// Errors
var (
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
ErrNotFound = NewHTTPError(http.StatusNotFound)
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
ErrBadRequest = NewHTTPError(http.StatusBadRequest)
ErrBadGateway = NewHTTPError(http.StatusBadGateway)
ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// HTTPError represents an error that occurred while handling a request.
type HTTPError struct {
Code int `json:"-"`
Message interface{} `json:"message"`
Internal error `json:"-"` // Stores the error returned by an external dependency
}
// NewHTTPError creates a new HTTPError instance.
func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used?
he := &HTTPError{Code: code, Message: http.StatusText(code)}
if len(message) > 0 {
he.Message = message[0]
}
return he
}
// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set.
func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError {
he := NewHTTPError(code, message...)
he.Internal = internalError
return he
}
// Error makes it compatible with `error` interface.
func (he *HTTPError) Error() string {
if he.Internal == nil {
return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
}
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
}
// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
func (he *HTTPError) WithInternal(err error) *HTTPError {
return &HTTPError{
Code: he.Code,
Message: he.Message,
Internal: err,
}
}
// Unwrap satisfies the Go 1.13 error wrapper interface.
func (he *HTTPError) Unwrap() error {
return he.Internal
}

52
httperror_test.go Normal file
View File

@ -0,0 +1,52 @@
package echo
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestHTTPError(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Equal(t, "code=400, message=map[code:12]", err.Error())
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
})
}
func TestNewHTTPErrorWithInternal(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message")
assert.Equal(t, "code=400, message=test message, internal=test", he.Error())
}
func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"))
assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error())
}
func TestHTTPError_Unwrap(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Nil(t, errors.Unwrap(err))
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "internal error", errors.Unwrap(err).Error())
})
}

11
json.go
View File

@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string
func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
err := json.NewDecoder(c.Request().Body).Decode(i)
if ute, ok := err.(*json.UnmarshalTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
return NewHTTPErrorWithInternal(
http.StatusBadRequest,
err,
fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset),
)
} else if se, ok := err.(*json.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest,
err,
fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()),
)
}
return err
}

View File

@ -14,7 +14,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
c := e.NewContext(req, rec).(*DefaultContext)
assert := testify.New(t)
@ -40,7 +40,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c = e.NewContext(req, rec).(*DefaultContext)
err = enc.Serialize(c, user{1, "Jon Snow"}, " ")
if assert.NoError(err) {
assert.Equal(userJSONPretty+"\n", rec.Body.String())
@ -53,7 +53,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
c := e.NewContext(req, rec).(*DefaultContext)
assert := testify.New(t)
@ -81,7 +81,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
var userUnmarshalSyntaxError = user{}
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c = e.NewContext(req, rec).(*DefaultContext)
err = enc.Deserialize(c, &userUnmarshalSyntaxError)
assert.IsType(&HTTPError{}, err)
assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value")
@ -93,7 +93,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c = e.NewContext(req, rec).(*DefaultContext)
err = enc.Deserialize(c, &userUnmarshalTypeError)
assert.IsType(&HTTPError{}, err)
assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string")

175
log.go
View File

@ -1,41 +1,148 @@
package echo
import (
"bytes"
"io"
"github.com/labstack/gommon/log"
"strconv"
"sync"
"time"
)
type (
// Logger defines the logging interface.
Logger interface {
Output() io.Writer
SetOutput(w io.Writer)
Prefix() string
SetPrefix(p string)
Level() log.Lvl
SetLevel(v log.Lvl)
SetHeader(h string)
Print(i ...interface{})
Printf(format string, args ...interface{})
Printj(j log.JSON)
Debug(i ...interface{})
Debugf(format string, args ...interface{})
Debugj(j log.JSON)
Info(i ...interface{})
Infof(format string, args ...interface{})
Infoj(j log.JSON)
Warn(i ...interface{})
Warnf(format string, args ...interface{})
Warnj(j log.JSON)
Error(i ...interface{})
Errorf(format string, args ...interface{})
Errorj(j log.JSON)
Fatal(i ...interface{})
Fatalj(j log.JSON)
Fatalf(format string, args ...interface{})
Panic(i ...interface{})
Panicj(j log.JSON)
Panicf(format string, args ...interface{})
//-----------------------------------------------------------------------------
// Example for Zap (https://github.com/uber-go/zap)
//func main() {
// e := echo.New()
// logger, _ := zap.NewProduction()
// e.Logger = &ZapLogger{logger: logger}
//}
//type ZapLogger struct {
// logger *zap.Logger
//}
//
//func (l *ZapLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *ZapLogger) Error(err error) {
// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo"))
//}
//-----------------------------------------------------------------------------
// Example for Zerolog (https://github.com/rs/zerolog)
//func main() {
// e := echo.New()
// logger := zerolog.New(os.Stdout)
// e.Logger = &ZeroLogger{logger: &logger}
//}
//
//type ZeroLogger struct {
// logger *zerolog.Logger
//}
//
//func (l *ZeroLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *ZeroLogger) Error(err error) {
// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error())
//}
//-----------------------------------------------------------------------------
// Example for Logrus (https://github.com/sirupsen/logrus)
//func main() {
// e := echo.New()
// e.Logger = &LogrusLogger{logger: logrus.New()}
//}
//
//type LogrusLogger struct {
// logger *logrus.Logger
//}
//
//func (l *LogrusLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *LogrusLogger) Error(err error) {
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err)
//}
// Logger defines the logging interface that Echo uses internally in few places.
// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice.
type Logger interface {
// Write provides writer interface for http.Server `ErrorLog` and for logging startup messages.
// `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers,
// and underlying FileSystem errors.
// `logger` middleware will use this method to write its JSON payload.
Write(p []byte) (n int, err error)
// Error logs the error
Error(err error)
}
// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only
// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting.
// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases.
type jsonLogger struct {
writer io.Writer
bufferPool sync.Pool
lock sync.Mutex
timeNow func() time.Time
}
func newJSONLogger(writer io.Writer) *jsonLogger {
return &jsonLogger{
writer: writer,
bufferPool: sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
},
},
timeNow: time.Now,
}
}
func (l *jsonLogger) Write(p []byte) (n int, err error) {
pLen := len(p)
if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message
(p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') ||
(p[0] == '{' && p[pLen-1] == '}') {
return l.write(p)
}
// we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is
// called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably
// deserves at least WARN level.
return l.printf("INFO", string(p))
}
func (l *jsonLogger) Error(err error) {
_, _ = l.printf("ERROR", err.Error())
}
func (l *jsonLogger) printf(level string, message string) (n int, err error) {
buf := l.bufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer l.bufferPool.Put(buf)
buf.WriteString(`{"time":"`)
buf.WriteString(l.timeNow().Format(time.RFC3339Nano))
buf.WriteString(`","level":"`)
buf.WriteString(level)
buf.WriteString(`","prefix":"echo","message":`)
buf.WriteString(strconv.Quote(message))
buf.WriteString("}\n")
return l.write(buf.Bytes())
}
func (l *jsonLogger) write(p []byte) (int, error) {
l.lock.Lock()
defer l.lock.Unlock()
return l.writer.Write(p)
}
)

87
log_test.go Normal file
View File

@ -0,0 +1,87 @@
package echo
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type noOpLogger struct {
}
func (l *noOpLogger) Write(p []byte) (n int, err error) {
return 0, err
}
func (l *noOpLogger) Error(err error) {
}
func TestJsonLogger_Write(t *testing.T) {
var testCases = []struct {
name string
when []byte
expect string
}{
{
name: "ok, write non JSONlike message",
when: []byte("version: %v, build: %v"),
expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: %v, build: %v"}` + "\n",
},
{
name: "ok, write quoted message",
when: []byte(`version: "%v"`),
expect: `{"time":"2021-09-07T20:09:37Z","level":"INFO","prefix":"echo","message":"version: \"%v\""}` + "\n",
},
{
name: "ok, write JSON",
when: []byte(`{"version": 123}` + "\n"),
expect: `{"version": 123}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0).UTC()
}
_, err := logger.Write(tc.when)
result := buf.String()
assert.Equal(t, tc.expect, result)
assert.NoError(t, err)
})
}
}
func TestJsonLogger_Error(t *testing.T) {
var testCases = []struct {
name string
whenError error
expect string
}{
{
name: "ok",
whenError: ErrForbidden,
expect: `{"time":"2021-09-07T20:09:37Z","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0).UTC()
}
logger.Error(tc.whenError)
result := buf.String()
assert.Equal(t, tc.expect, result)
})
}
}

11
middleware/DEVELOPMENT.md Normal file
View File

@ -0,0 +1,11 @@
# Development Guidelines for middlewares
## Best practices:
* Do not use `panic` in middleware creator functions in case of invalid configuration.
* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
because previous middlewares up in call chain could have logic for dealing with returned errors.
* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
want to create middleware with panics or with returning errors on configuration errors.
* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.

View File

@ -1,64 +1,59 @@
package middleware
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// BasicAuthConfig defines the config for BasicAuth middleware.
BasicAuthConfig struct {
// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
type BasicAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Validator is a function to validate BasicAuth credentials.
// Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
// this function would be called once for each header until first valid result is returned
// Required.
Validator BasicAuthValidator
// Realm is a string to define realm attribute of BasicAuth.
// Realm is a string to define realm attribute of BasicAuthWithConfig.
// Default value "Restricted".
Realm string
}
// BasicAuthValidator defines a function to validate BasicAuth credentials.
BasicAuthValidator func(string, string, echo.Context) (bool, error)
)
// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error)
const (
basic = "basic"
defaultRealm = "Restricted"
)
var (
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: DefaultSkipper,
Realm: defaultRealm,
}
)
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
}
// BasicAuthWithConfig returns an BasicAuth middleware with config.
// See `BasicAuth()`.
// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil {
panic("echo: basic-auth middleware requires a validator function")
return nil, errors.New("echo basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Realm == "" {
config.Realm = defaultRealm
@ -70,27 +65,31 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c)
}
auth := c.Request().Header.Get(echo.HeaderAuthorization)
var lastError error
l := len(basic)
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
continue
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
valid, err := config.Validator(cred[:i], cred[i+1:], c)
if err != nil {
return err
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
continue
}
idx := bytes.IndexByte(b, ':')
if idx >= 0 {
valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
if errValidate != nil {
lastError = errValidate
} else if valid {
return next(c)
}
break
}
}
if lastError != nil {
return lastError
}
realm := defaultRealm
@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized
}
}
}, nil
}

View File

@ -2,70 +2,157 @@ package middleware
import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
validatorFunc := func(c echo.Context, u, p string) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
if u == "error" {
return false, errors.New(p)
}
return false, nil
}
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, multiple",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
basic + " NOT_BASE64",
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
config := tc.givenConfig
mw, err := config.ToMiddleware()
assert.NoError(t, err)
h := mw(func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
if len(tc.whenAuth) != 0 {
for _, a := range tc.whenAuth {
req.Header.Add(echo.HeaderAuthorization, a)
}
}
err = h(c)
if tc.expectErr != "" {
assert.Equal(t, http.StatusOK, res.Code)
assert.EqualError(t, err, tc.expectErr)
} else {
assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
}
if tc.expectHeader != "" {
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
}
})
}
}
func TestBasicAuth_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuth(nil)
assert.NotNil(t, mw)
})
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
return true, nil
})
assert.NotNil(t, mw)
}
return false, nil
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
func TestBasicAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
assert.NotNil(t, mw)
})
assert := assert.New(t)
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
return true, nil
}})
assert.NotNil(t, mw)
}

View File

@ -3,17 +3,17 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// BodyDumpConfig defines the config for BodyDump middleware.
BodyDumpConfig struct {
type BodyDumpConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -23,51 +23,45 @@ type (
}
// BodyDumpHandler receives the request and response payload.
BodyDumpHandler func(echo.Context, []byte, []byte)
type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte)
bodyDumpResponseWriter struct {
type bodyDumpResponseWriter struct {
io.Writer
http.ResponseWriter
}
)
var (
// DefaultBodyDumpConfig is the default BodyDump middleware config.
DefaultBodyDumpConfig = BodyDumpConfig{
Skipper: DefaultSkipper,
}
)
// BodyDump returns a BodyDump middleware.
//
// BodyDump middleware captures the request and response payload and calls the
// registered handler.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
c := DefaultBodyDumpConfig
c.Handler = handler
return BodyDumpWithConfig(c)
return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
}
// BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Handler == nil {
panic("echo: body-dump middleware requires a handler function")
return nil, errors.New("echo body-dump middleware requires a handler function")
}
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
config.Skipper = DefaultSkipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
// Request
reqBody := []byte{}
if c.Request().Body != nil { // Read
if c.Request().Body != nil {
reqBody, _ = ioutil.ReadAll(c.Request().Body)
}
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer
if err = next(c); err != nil {
c.Error(err)
}
err := next(c)
// Callback
config.Handler(c, reqBody, resBody.Bytes())
return
}
return err
}
}, nil
}
func (w *bodyDumpResponseWriter) WriteHeader(code int) {

View File

@ -8,7 +8,7 @@ import (
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) {
requestBody := ""
responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
})
}}.ToMiddleware()
assert.NoError(t, err)
assert := assert.New(t)
if assert.NoError(mw(h)(c)) {
assert.Equal(requestBody, hw)
assert.Equal(responseBody, hw)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.String())
if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String())
}
// Must set default skipper
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
}
func TestBodyDump_skipper(t *testing.T) {
e := echo.New()
isCalled := false
mw, err := BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
})
Handler: func(c echo.Context, reqBody, resBody []byte) {
isCalled = true
},
}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
}
func TestBodyDumpFails(t *testing.T) {
err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.False(t, isCalled)
}
func TestBodyDump_fails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) {
return errors.New("some error")
}
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.Equal(t, http.StatusOK, rec.Code)
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
mw := BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
assert.NotNil(t, mw)
})
assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
Handler: func(c echo.Context, reqBody, resBody []byte) {
},
mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}})
assert.NotNil(t, mw)
})
}
func TestBodyDump_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BodyDump(nil)
assert.NotNil(t, mw)
})
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
assert.NotPanics(t, func() {
BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
})
}

View File

@ -1,98 +1,83 @@
package middleware
import (
"fmt"
"io"
"sync"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
"github.com/labstack/echo/v5"
)
type (
// BodyLimitConfig defines the config for BodyLimit middleware.
BodyLimitConfig struct {
// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
type BodyLimitConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Maximum allowed size for a request body, it can be specified
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
Limit string `yaml:"limit"`
limit int64
// LimitBytes is maximum allowed size in bytes for a request body
LimitBytes int64
}
limitedReader struct {
type limitedReader struct {
BodyLimitConfig
reader io.ReadCloser
read int64
context echo.Context
}
)
var (
// DefaultBodyLimitConfig is the default BodyLimit middleware config.
DefaultBodyLimitConfig = BodyLimitConfig{
Skipper: DefaultSkipper,
}
)
// BodyLimit returns a BodyLimit middleware.
//
// BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The BodyLimit is determined based on both `Content-Length` request
// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
// G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc {
c := DefaultBodyLimitConfig
c.Limit = limit
return BodyLimitWithConfig(c)
func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
}
// BodyLimitWithConfig returns a BodyLimit middleware with config.
// See: `BodyLimit()`.
// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
// makes it super secure.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultBodyLimitConfig.Skipper
return toMiddlewareOrPanic(config)
}
limit, err := bytes.Parse(config.Limit)
if err != nil {
panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
pool := sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: config}
},
}
config.limit = limit
pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
// Based on content length
if req.ContentLength > config.limit {
if req.ContentLength > config.LimitBytes {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
r := pool.Get().(*limitedReader)
r.Reset(req.Body, c)
r.Reset(c, req.Body)
defer pool.Put(r)
req.Body = r
return next(c)
}
}
}, nil
}
func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
if r.read > r.limit {
if r.read > r.LimitBytes {
return n, echo.ErrStatusRequestEntityTooLarge
}
return
@ -102,16 +87,8 @@ func (r *limitedReader) Close() error {
return r.reader.Close()
}
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
r.reader = reader
r.context = context
r.read = 0
}
func limitedReaderPool(c BodyLimitConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: c}
},
}
}

View File

@ -7,10 +7,143 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
// Based on content length (within limit)
mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he := mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String())
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he = mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
LimitBytes: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
}
func TestBodyLimit_skipper(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw, err := BodyLimitConfig{
Skipper: func(c echo.Context) bool {
return true
},
LimitBytes: 2,
}.ToMiddleware()
assert.NoError(t, err)
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimitWithConfig(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimit(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
@ -25,63 +158,10 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body))
}
assert := assert.New(t)
mw := BodyLimit(2 * MB)
// Based on content length (within limit)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.Bytes())
}
// Based on content length (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, World!", rec.Body.String())
}
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
Limit: "2B",
limit: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net"
@ -10,54 +11,49 @@ import (
"strings"
"sync"
"github.com/labstack/echo/v4"
)
type (
// GzipConfig defines the config for Gzip middleware.
GzipConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Gzip compression level.
// Optional. Default value -1.
Level int `yaml:"level"`
}
gzipResponseWriter struct {
io.Writer
http.ResponseWriter
wroteBody bool
}
"github.com/labstack/echo/v5"
)
const (
gzipScheme = "gzip"
)
var (
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig{
Skipper: DefaultSkipper,
Level: -1,
}
)
// GzipConfig defines the config for Gzip middleware.
type GzipConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
// Gzip compression level.
// Optional. Default value -1.
Level int
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
wroteBody bool
}
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(DefaultGzipConfig)
return GzipWithConfig(GzipConfig{})
}
// GzipWithConfig return Gzip middleware with config.
// See: `Gzip()`.
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
return nil, errors.New("invalid gzip level")
}
if config.Level == 0 {
config.Level = DefaultGzipConfig.Level
config.Level = -1
}
pool := gzipCompressPool(config)
@ -98,7 +94,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}
func (w *gzipResponseWriter) WriteHeader(code int) {

View File

@ -3,94 +3,128 @@ package middleware
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestGzip(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
// Skip if no Accept-Encoding header
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
assert.Equal("test", rec.Body.String())
err := h(c)
assert.NoError(t, err)
// Gzip
req = httptest.NewRequest(http.MethodGet, "/", nil)
assert.Equal(t, "test", rec.Body.String())
}
func TestMustGzipWithConfig_panics(t *testing.T) {
assert.Panics(t, func() {
GzipWithConfig(GzipConfig{Level: 999})
})
}
func TestGzip_AcceptEncodingHeader(t *testing.T) {
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body)
if assert.NoError(err) {
assert.NoError(t, err)
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "test", buf.String())
}
chunkBuf := make([]byte, 5)
// Gzip chunked
req = httptest.NewRequest(http.MethodGet, "/", nil)
func TestGzip_chunked(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c = e.NewContext(req, rec)
Gzip()(func(c echo.Context) error {
chunkChan := make(chan struct{})
waitChan := make(chan struct{})
h := Gzip()(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
c.Response().Write([]byte("test\n"))
c.Response().Write([]byte("first\n"))
c.Response().Flush()
// Read the first part of the data
assert.True(rec.Flushed)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r.Reset(rec.Body)
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
chunkChan <- struct{}{}
<-waitChan
// Write and flush the second part of the data
c.Response().Write([]byte("test\n"))
c.Response().Write([]byte("second\n"))
c.Response().Flush()
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
chunkChan <- struct{}{}
<-waitChan
// Write the final part of the data and return
c.Response().Write([]byte("test"))
return nil
})(c)
c.Response().Write([]byte("third"))
chunkChan <- struct{}{}
return nil
})
go func() {
err := h(c)
chunkChan <- struct{}{}
assert.NoError(t, err)
}()
<-chunkChan // wait for first write
waitChan <- struct{}{}
<-chunkChan // wait for second write
waitChan <- struct{}{}
<-chunkChan // wait for final write in handler
<-chunkChan // wait for return from handler
time.Sleep(5 * time.Millisecond) // to have time for flushing
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r, err := gzip.NewReader(rec.Body)
assert.NoError(t, err)
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "first\nsecond\nthird", buf.String())
}
func TestGzipNoContent(t *testing.T) {
func TestGzip_NoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) {
}
}
func TestGzipEmpty(t *testing.T) {
func TestGzip_Empty(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
@ -127,7 +161,7 @@ func TestGzipEmpty(t *testing.T) {
}
}
func TestGzipErrorReturned(t *testing.T) {
func TestGzip_ErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Gzip())
e.GET("/", func(c echo.Context) error {
@ -141,31 +175,25 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
e := echo.New()
// Invalid level
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
func TestGzipWithConfig_invalidLevel(t *testing.T) {
mw, err := GzipConfig{Level: 12}.ToMiddleware()
assert.EqualError(t, err, "invalid gzip level")
assert.Nil(t, mw)
}
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
e.Filesystem = os.DirFS("../")
e.Use(Gzip())
e.Static("/test", "../_fixture/images")
e.Static("/test", "_fixture/images")
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Data is written out in chunks when Content-Length == "", so only
// validate the content length if it's not set.

View File

@ -6,37 +6,36 @@ import (
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// CORSConfig defines the config for CORS middleware.
CORSConfig struct {
type CORSConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
AllowOriginFunc func(origin string) (bool, error)
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value
// from `Allow` header that echo.Router set into context.
AllowMethods []string `yaml:"allow_methods"`
AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
// Optional. Default value []string{}.
AllowHeaders []string `yaml:"allow_headers"`
AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
@ -45,28 +44,25 @@ type (
// Optional. Default value false.
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
AllowCredentials bool `yaml:"allow_credentials"`
AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
// Optional. Default value []string{}.
ExposeHeaders []string `yaml:"expose_headers"`
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
// Optional. Default value 0.
MaxAge int `yaml:"max_age"`
MaxAge int
}
)
var (
// DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{
var DefaultCORSConfig = CORSConfig{
Skipper: DefaultSkipper,
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
)
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
@ -74,9 +70,14 @@ func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig)
}
// CORSWithConfig returns a CORS middleware with config.
// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
// See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper
@ -172,7 +173,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
checkPatterns := false
if allowOrigin == "" {
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
if len(origin) <= (5+3+253) && strings.Contains(origin, "://") {
checkPatterns = true
}
}
@ -230,5 +231,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
return c.NoContent(http.StatusNoContent)
}
}
}, nil
}

View File

@ -6,7 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -17,7 +17,7 @@ func TestCORS(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -26,7 +26,7 @@ func TestCORS(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
@ -38,7 +38,7 @@ func TestCORS(t *testing.T) {
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler)
})(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -55,7 +55,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -73,7 +73,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -90,7 +90,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
@ -104,7 +104,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
h := cors(echo.NotFoundHandler)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
if tt.expected {
@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
h := cors(echo.NotFoundHandler)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
if tt.expected {
@ -324,7 +324,9 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) {
c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
}
h := cors(echo.NotFoundHandler)
h := cors(func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
h(c)
assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
@ -511,11 +513,11 @@ func Test_allowOriginFunc(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware()
assert.NoError(t, err)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
err = h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {

View File

@ -5,18 +5,16 @@ import (
"net/http"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/labstack/echo/v5"
)
type (
// CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct {
type CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// TokenLength is the length of the generated token.
TokenLength uint8 `yaml:"token_length"`
TokenLength uint8
// Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
@ -30,46 +28,48 @@ type (
// - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"`
// Generator defines a function to generate token.
// Optional. Defaults tp randomString(TokenLength).
Generator func() string
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
ContextKey string `yaml:"context_key"`
ContextKey string
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
CookieName string `yaml:"cookie_name"`
CookieName string
// Domain of the CSRF cookie.
// Optional. Default value none.
CookieDomain string `yaml:"cookie_domain"`
CookieDomain string
// Path of the CSRF cookie.
// Optional. Default value none.
CookiePath string `yaml:"cookie_path"`
CookiePath string
// Max age (in seconds) of the CSRF cookie.
// Optional. Default value 86400 (24hr).
CookieMaxAge int `yaml:"cookie_max_age"`
CookieMaxAge int
// Indicates if CSRF cookie is secure.
// Optional. Default value false.
CookieSecure bool `yaml:"cookie_secure"`
CookieSecure bool
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"`
CookieHTTPOnly bool
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
CookieSameSite http.SameSite
}
)
// ErrCSRFInvalid is returned when CSRF check fails
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
var DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken,
@ -78,18 +78,20 @@ var (
CookieMaxAge: 86400,
CookieSameSite: http.SameSiteDefaultMode,
}
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
c := DefaultCSRFConfig
return CSRFWithConfig(c)
return CSRFWithConfig(DefaultCSRFConfig)
}
// CSRFWithConfig returns a CSRF middleware with config.
// See `CSRF()`.
// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper
@ -97,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}
if config.Generator == nil {
config.Generator = createRandomStringGenerator(config.TokenLength)
}
if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup
}
@ -113,9 +118,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieSecure = true
}
extractors, err := createExtractors(config.TokenLookup, "")
extractors, err := createExtractors(config.TokenLookup)
if err != nil {
panic(err)
return nil, err
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -126,7 +131,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
token = random.String(config.TokenLength) // Generate token
token = config.Generator() // Generate token
} else {
token = k.Value // Reuse token
}
@ -157,17 +162,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if lastTokenErr != nil {
return lastTokenErr
} else if lastExtractorErr != nil {
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
} else if lastExtractorErr == errFormExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
} else {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
}
return lastExtractorErr
return echo.ErrBadRequest.WithInternal(lastExtractorErr)
}
}
@ -197,7 +192,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c)
}
}
}, nil
}
func validateCSRFToken(token, clientToken string) bool {

View File

@ -7,8 +7,7 @@ import (
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -23,6 +22,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenFormTokens map[string][]string
givenHeaderTokens map[string][]string
expectError string
expectToMiddlewareError string
}{
{
name: "ok, multiple token lookups sources, succeeds on last one",
@ -70,7 +70,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in the form parameter",
expectError: "code=400, message=Bad Request, internal=missing value in the form",
},
{
name: "ok, token from POST header",
@ -106,7 +106,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in request header",
expectError: "code=400, message=Bad Request, internal=missing value in request header",
},
{
name: "ok, token from PUT query param",
@ -142,7 +142,15 @@ func TestCSRF_tokenExtractors(t *testing.T) {
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in the query string",
expectError: "code=400, message=Bad Request, internal=missing value in the query string",
},
{
name: "nok, invalid TokenLookup",
whenTokenLookup: "q",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{},
expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q",
},
}
@ -186,16 +194,23 @@ func TestCSRF_tokenExtractors(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
config := CSRFConfig{
TokenLookup: tc.whenTokenLookup,
CookieName: tc.whenCookieName,
})
}
csrf, err := config.ToMiddleware()
if tc.expectToMiddlewareError != "" {
assert.EqualError(t, err, tc.expectToMiddlewareError)
return
} else if err != nil {
assert.NoError(t, err)
}
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
err := h(c)
err = h(c)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -219,6 +234,24 @@ func TestCSRF(t *testing.T) {
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
}
func TestMustCSRFWithConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
TokenLength: 16,
})
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Generate CSRF token
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
// Without CSRF cookie
req = httptest.NewRequest(http.MethodPost, "/", nil)
rec = httptest.NewRecorder()
@ -233,7 +266,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c))
// Valid CSRF token
token := random.String(32)
token := randomString(16)
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) {
@ -302,9 +335,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
csrf, err := CSRFConfig{
CookieSameSite: http.SameSiteNoneMode,
})
}.ToMiddleware()
assert.NoError(t, err)
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")

View File

@ -6,19 +6,17 @@ import (
"net/http"
"sync"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
type DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
// GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
@ -28,14 +26,6 @@ type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
@ -46,17 +36,21 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
// Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
return DecompressWithConfig(DecompressConfig{})
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
config.GzipDecompressPool = &DefaultGzipDecompressPool{}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -95,5 +89,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
return next(c)
}
}
}, nil
}

View File

@ -11,12 +11,37 @@ import (
"sync"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestDecompress(t *testing.T) {
e := echo.New()
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
// Decompress request body
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
func TestDecompress_skippedIfNoHeader(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@ -26,39 +51,42 @@ func TestDecompress(t *testing.T) {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
err := h(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestDecompressDefaultConfig(t *testing.T) {
func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
h, err := DecompressConfig{}.ToMiddleware()
assert.NoError(t, err)
err = h(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
}
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) {
h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
err := h(c)
if assert.NoError(t, err) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)

View File

@ -1,9 +1,8 @@
package middleware
import (
"errors"
"fmt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"net/textproto"
"strings"
)
@ -14,17 +13,27 @@ const (
extractorLimit = 20
)
var errHeaderExtractorValueMissing = errors.New("missing value in request header")
var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
var errQueryExtractorValueMissing = errors.New("missing value in the query string")
var errParamExtractorValueMissing = errors.New("missing value in path params")
var errCookieExtractorValueMissing = errors.New("missing value in cookies")
var errFormExtractorValueMissing = errors.New("missing value in the form")
// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
type ValueExtractorError struct {
message string
}
// Error returns errors text
func (e *ValueExtractorError) Error() string {
return e.message
}
var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"}
var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"}
var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"}
var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"}
var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"}
var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
type ValuesExtractor func(c echo.Context) ([]string, error)
func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
func createExtractors(lookups string) ([]ValuesExtractor, error) {
if lookups == "" {
return nil, nil
}
@ -49,15 +58,6 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
} else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
// backwards compatibility for JWT and KeyAuth:
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
// behaviour for default values and Authorization header.
prefix = authScheme
if !strings.HasSuffix(prefix, " ") {
prefix += " "
}
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
@ -125,10 +125,9 @@ func valuesFromQuery(param string) ValuesExtractor {
func valuesFromParam(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
result := make([]string, 0)
paramVales := c.ParamValues()
for i, p := range c.ParamNames() {
if param == p {
result = append(result, paramVales[i])
for i, p := range c.PathParams() {
if param == p.Name {
result = append(result, p.Value)
if i >= extractorLimit-1 {
break
}

View File

@ -3,7 +3,7 @@ package middleware
import (
"bytes"
"fmt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"mime/multipart"
"net/http"
@ -13,27 +13,11 @@ import (
"testing"
)
type pathParam struct {
name string
value string
}
func setPathParams(c echo.Context, params []pathParam) {
names := make([]string, 0, len(params))
values := make([]string, 0, len(params))
for _, pp := range params {
names = append(names, pp.name)
values = append(values, pp.value)
}
c.SetParamNames(names...)
c.SetParamValues(values...)
}
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
givenPathParams []pathParam
givenPathParams echo.PathParams
whenLoopups string
expectValues []string
expectCreateError string
@ -74,8 +58,8 @@ func TestCreateExtractors(t *testing.T) {
},
{
name: "ok, param",
givenPathParams: []pathParam{
{name: "id", value: "123"},
givenPathParams: echo.PathParams{
{Name: "id", Value: "123"},
},
whenLoopups: "param:id",
expectValues: []string{"123"},
@ -105,12 +89,12 @@ func TestCreateExtractors(t *testing.T) {
req = tc.givenRequest()
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c := e.NewContext(req, rec).(echo.ServableContext)
if tc.givenPathParams != nil {
setPathParams(c, tc.givenPathParams)
c.SetRawPathParams(&tc.givenPathParams)
}
extractors, err := createExtractors(tc.whenLoopups, "")
extractors, err := createExtractors(tc.whenLoopups)
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
@ -317,19 +301,19 @@ func TestValuesFromQuery(t *testing.T) {
}
func TestValuesFromParam(t *testing.T) {
examplePathParams := []pathParam{
{name: "id", value: "123"},
{name: "gid", value: "456"},
{name: "gid", value: "789"},
examplePathParams := echo.PathParams{
{Name: "id", Value: "123"},
{Name: "gid", Value: "456"},
{Name: "gid", Value: "789"},
}
examplePathParams20 := make([]pathParam, 0)
examplePathParams20 := make(echo.PathParams, 0)
for i := 1; i < 25; i++ {
examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)})
examplePathParams20 = append(examplePathParams20, echo.PathParam{Name: "id", Value: fmt.Sprintf("%v", i)})
}
var testCases = []struct {
name string
givenPathParams []pathParam
givenPathParams echo.PathParams
whenName string
expectValues []string
expectError string
@ -377,9 +361,9 @@ func TestValuesFromParam(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c := e.NewContext(req, rec).(echo.ServableContext)
if tc.givenPathParams != nil {
setPathParams(c, tc.givenPathParams)
c.SetRawPathParams(&tc.givenPathParams)
}
extractor := valuesFromParam(tc.whenName)

View File

@ -1,36 +1,30 @@
//go:build go1.15
// +build go1.15
package middleware
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"net/http"
"reflect"
)
type (
// JWTConfig defines the config for JWT middleware.
JWTConfig struct {
type JWTConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc BeforeFunc
// SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
// middleware or handler.
// SuccessHandler defines a function which is executed for a valid token.
SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token.
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
// It may be used to define a custom JWT error.
ErrorHandler JWTErrorHandler
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain.
ErrorHandler JWTErrorHandlerWithContext
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to
// ignore the error (by returning `nil`).
@ -39,31 +33,10 @@ type (
// and continue. Some logic down the remaining execution chain needs to check that (public) token value then.
ContinueOnIgnoredError bool
// Signing key to validate token.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKeys is provided.
SigningKey interface{}
// Map of signing keys to validate token with kid field usage.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKey is provided.
SigningKeys map[string]interface{}
// Signing method used to check the token's signing algorithm.
// Optional. Default value HS256.
SigningMethod string
// Context key to store user information from the token into context.
// Optional. Default value "user".
ContextKey string
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
// Not used if custom ParseTokenFunc is set.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:Authorization".
@ -79,7 +52,7 @@ type (
// - "cookie:<name>"
// - "form:<name>"
// Multiple sources example:
// - "header:Authorization,cookie:myowncookie"
// - "header:Authorization:Bearer ,cookie:myowncookie"
TokenLookup string
// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
@ -88,63 +61,38 @@ type (
// You can also provide both if you want.
TokenLookupFuncs []ValuesExtractor
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
}
// JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(c echo.Context)
type JWTSuccessHandler func(c echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(err error) error
type JWTErrorHandler func(err error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(err error, c echo.Context) error
)
type JWTErrorHandlerWithContext func(c echo.Context, err error) error
// Algorithms
const (
// AlgorithmHS256 is token signing algorithm
AlgorithmHS256 = "HS256"
)
// Errors
var (
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
)
// ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request
var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt")
// ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired
var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
var (
// DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{
var DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
TokenLookupFuncs: nil,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
}
)
// JWT returns a JSON Web Token (JWT) auth middleware.
//
@ -153,48 +101,40 @@ var (
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
// See `JWTConfig.TokenLookup`
func JWT(key interface{}) echo.MiddlewareFunc {
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
c := DefaultJWTConfig
c.SigningKey = key
c.ParseTokenFunc = parseTokenFunc
return JWTWithConfig(c)
}
// JWTWithConfig returns a JWT auth middleware with config.
// See: `JWT()`.
// JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid.
//
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it returns "401 - Unauthorized" error.
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration
func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
panic("echo: jwt middleware requires signing key")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTConfig.SigningMethod
if config.ParseTokenFunc == nil {
return nil, errors.New("echo jwt middleware requires parse token function")
}
if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey
}
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 {
config.TokenLookup = DefaultJWTConfig.TokenLookup
}
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
if config.KeyFunc == nil {
config.KeyFunc = config.defaultKeyFunc
}
if config.ParseTokenFunc == nil {
config.ParseTokenFunc = config.defaultParseToken
}
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
extractors, err := createExtractors(config.TokenLookup)
if err != nil {
panic(err)
return nil, err
}
if len(config.TokenLookupFuncs) > 0 {
extractors = append(config.TokenLookupFuncs, extractors...)
@ -209,17 +149,16 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.BeforeFunc != nil {
config.BeforeFunc(c)
}
var lastExtractorErr error
var lastTokenErr error
for _, extractor := range extractors {
auths, err := extractor(c)
if err != nil {
lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth)
auths, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
for _, auth := range auths {
token, err := config.ParseTokenFunc(auth, c)
token, err := config.ParseTokenFunc(c, auth)
if err != nil {
lastTokenErr = err
continue
@ -232,69 +171,23 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
return next(c)
}
}
// we are here only when we did not successfully extract or parse any of the tokens
// prioritize token errors over extracting errors
err := lastTokenErr
if err == nil { // prioritize token errors over extracting errors
if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
}
if config.ErrorHandlerWithContext != nil {
tmpErr := config.ErrorHandlerWithContext(err, c)
tmpErr := config.ErrorHandler(c, err)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
}
return tmpErr
}
// backwards compatible errors codes
if lastTokenErr != nil {
return &echo.HTTPError{
Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message,
Internal: err,
if lastTokenErr == nil {
return ErrJWTMissing.WithInternal(err)
}
return ErrJWTInvalid.WithInternal(err)
}
return err // this is lastExtractorErr value
}
}
}
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
token := new(jwt.Token)
var err error
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return config.SigningKey, nil
}, nil
}

View File

@ -0,0 +1,76 @@
package middleware_test
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"net/http"
"net/http/httptest"
)
// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc
//
// signingKey is signing key to validate token.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKeys is not provided.
//
// signingKeys is Map of signing keys to validate token with kid field usage.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKey is not provided
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
// keyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != middleware.AlgorithmHS256 {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(signingKeys) == 0 {
return signingKey, nil
}
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := signingKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
func ExampleJWTConfig_withJWTGoAsTokenParser() {
mw := middleware.JWTWithConfig(middleware.JWTConfig{
ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil),
})
e := echo.New()
e.Use(mw)
e.GET("/", func(c echo.Context) error {
user := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusTeapot, user.Claims)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
fmt.Printf("status: %v, body: %v", res.Code, res.Body.String())
// Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"}
}

View File

@ -1,6 +1,3 @@
//go:build go1.15
// +build go1.15
package middleware
import (
@ -12,11 +9,32 @@ import (
"strings"
"testing"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != signingMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return signingKey, nil
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
// jwtCustomInfo defines some custom types we're going to use within our tokens.
type jwtCustomInfo struct {
Name string `json:"name"`
@ -25,7 +43,7 @@ type jwtCustomInfo struct {
// jwtCustomClaims are custom claims expanding default ones.
type jwtCustomClaims struct {
*jwt.StandardClaims
*jwt.RegisteredClaims
jwtCustomInfo
}
@ -37,7 +55,7 @@ func TestJWT(t *testing.T) {
return c.JSON(http.StatusOK, token.Claims)
})
e.Use(JWT([]byte("secret")))
e.Use(JWT(createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret"))))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
@ -49,112 +67,96 @@ func TestJWT(t *testing.T) {
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
}
func TestJWTRace(t *testing.T) {
func TestJWT_combinations(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
validKey := []byte("secret")
h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: validKey,
})(handler)
makeReq := func(token string) echo.Context {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
c := e.NewContext(req, res)
assert.NoError(t, h(c))
return c
}
c := makeReq(initialToken)
user := c.Get("user").(*jwt.Token)
claims := user.Claims.(*jwtCustomClaims)
assert.Equal(t, claims.Name, "John Doe")
makeReq(raceToken)
user = c.Get("user").(*jwt.Token)
claims = user.Claims.(*jwtCustomClaims)
// Initial context should still be "John Doe", not "Race Condition"
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
}
func TestJWTConfig(t *testing.T) {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
validKey := []byte("secret")
invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token
validAuth := "Bearer " + token
testCases := []struct {
var testCases = []struct {
name string
expPanic bool
expErrCode int // 0 for Success
config JWTConfig
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
expectPanic bool
expectToMiddlewareError string
expectError string
}{
{
name: "No signing key provided",
expPanic: true,
expectToMiddlewareError: "echo jwt middleware requires parse token function",
},
{
name: "invalid TokenLookup",
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey),
TokenLookup: "q",
},
expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q",
},
{
name: "Unexpected signing method",
expErrCode: http.StatusBadRequest,
hdrAuth: validAuth,
config: JWTConfig{
SigningKey: validKey,
SigningMethod: "RS256",
ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey),
},
expectError: "code=401, message=invalid or expired jwt, internal=unexpected jwt signing method=HS256",
},
{
name: "Invalid key",
expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth,
config: JWTConfig{SigningKey: invalidKey},
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey),
},
expectError: "code=401, message=invalid or expired jwt, internal=signature is invalid",
},
{
name: "Valid JWT",
hdrAuth: validAuth,
config: JWTConfig{SigningKey: validKey},
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
},
{
name: "Valid JWT with custom AuthScheme",
hdrAuth: "Token" + " " + token,
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
config: JWTConfig{
TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ",
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
},
{
name: "Valid JWT with custom claims",
hdrAuth: validAuth,
config: JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: []byte("secret"),
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
},
},
{
name: "Invalid Authorization header",
hdrAuth: "invalid-auth",
expErrCode: http.StatusBadRequest,
config: JWTConfig{SigningKey: validKey},
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header",
},
{
name: "Empty header auth field",
config: JWTConfig{SigningKey: validKey},
expErrCode: http.StatusBadRequest,
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header",
},
{
name: "Valid query method",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwt=" + token,
@ -162,75 +164,75 @@ func TestJWTConfig(t *testing.T) {
{
name: "Invalid query param name",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwtxyz=" + token,
expErrCode: http.StatusBadRequest,
expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string",
},
{
name: "Invalid query param value",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwt=invalid-token",
expErrCode: http.StatusUnauthorized,
expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments",
},
{
name: "Empty query",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b",
expErrCode: http.StatusBadRequest,
expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string",
},
{
name: "Valid param method",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "param:jwt",
},
reqURL: "/" + token,
name: "Valid param method",
},
{
name: "Valid cookie method",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt",
},
hdrCookie: "jwt=" + token,
name: "Valid cookie method",
},
{
name: "Multiple jwt lookuop",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt,cookie:jwt",
},
hdrCookie: "jwt=" + token,
name: "Multiple jwt lookuop",
},
{
name: "Invalid token with cookie method",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt",
},
expErrCode: http.StatusUnauthorized,
hdrCookie: "jwt=invalid",
expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments",
},
{
name: "Empty cookie",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt",
},
expErrCode: http.StatusBadRequest,
expectError: "code=401, message=missing or malformed jwt, internal=missing value in cookies",
},
{
name: "Valid form method",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
@ -238,58 +240,24 @@ func TestJWTConfig(t *testing.T) {
{
name: "Invalid token with form method",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
expectError: "code=401, message=invalid or expired jwt, internal=token contains an invalid number of segments",
},
{
name: "Empty form field",
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
},
{
name: "Valid JWT with a valid key using a user-defined KeyFunc",
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil
},
},
},
{
name: "Valid JWT with an invalid key using a user-defined KeyFunc",
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return invalidKey, nil
},
},
expErrCode: http.StatusUnauthorized,
},
{
name: "Token verification does not pass using a user-defined KeyFunc",
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return nil, errors.New("faulty KeyFunc")
},
},
expErrCode: http.StatusUnauthorized,
},
{
name: "Valid JWT with lower case AuthScheme",
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
config: JWTConfig{SigningKey: validKey},
expectError: "code=401, message=missing or malformed jwt, internal=missing value in the form",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
if tc.reqURL == "" {
tc.reqURL = "/"
}
@ -312,128 +280,36 @@ func TestJWTConfig(t *testing.T) {
c := e.NewContext(req, res)
if tc.reqURL == "/"+token {
c.SetParamNames("jwt")
c.SetParamValues(token)
}
if tc.expPanic {
assert.Panics(t, func() {
JWTWithConfig(tc.config)
}, tc.name)
return
}
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code, tc.name)
return
}
h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.name) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe", tc.name)
case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe", tc.name)
assert.Equal(t, claims.Admin, true, tc.name)
default:
panic("unexpected type of claims")
}
}
cc := c.(echo.ServableContext)
cc.SetPathParams(echo.PathParams{
{Name: "jwt", Value: token},
})
}
mw, err := tc.config.ToMiddleware()
if tc.expectToMiddlewareError != "" {
assert.EqualError(t, err, tc.expectToMiddlewareError)
return
}
func TestJWTwithKID(t *testing.T) {
test := assert.New(t)
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
hErr := mw(handler)(c)
if tc.expectError != "" {
assert.EqualError(t, hErr, tc.expectError)
return
}
firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk"
secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM"
wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90"
staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o"
validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")}
invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")}
staticSecret := []byte("static_secret")
invalidStaticSecret := []byte("invalid_secret")
assert.NoError(t, hErr)
for _, tc := range []struct {
expErrCode int // 0 for Success
config JWTConfig
hdrAuth string
info string
}{
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: validKeys},
info: "First token valid",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Second token valid",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Wrong key id token",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: staticSecret},
info: "Valid static secret token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: invalidStaticSecret},
info: "Invalid static secret",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys first token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys second token",
},
} {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
c := e.NewContext(req, res)
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
test.Equal(tc.expErrCode, he.Code, tc.info)
continue
}
h := JWTWithConfig(tc.config)(handler)
if test.NoError(h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
test.Equal(claims["name"], "John Doe", tc.info)
assert.Equal(t, claims["name"], "John Doe")
case *jwtCustomClaims:
test.Equal(claims.Name, "John Doe", tc.info)
test.Equal(claims.Admin, true, tc.info)
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
default:
panic("unexpected type of claims")
}
}
})
}
}
@ -444,7 +320,7 @@ func TestJWTConfig_skipper(t *testing.T) {
Skipper: func(context echo.Context) bool {
return true // skip everything
},
SigningKey: []byte("secret"),
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
}))
isCalled := false
@ -472,11 +348,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) {
BeforeFunc: func(context echo.Context) {
isCalled = true
},
SigningKey: []byte("secret"),
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
@ -493,18 +369,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) {
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandler: func(err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
expectStatusCode: http.StatusTeapot,
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
@ -539,23 +405,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandler: func(err error) error {
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
},
}
for _, tc := range testCases {
@ -568,14 +424,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
config := tc.given
parseTokenCalled := false
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) {
config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
parseTokenCalled = true
return nil, errors.New("parsing failed")
}
e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
@ -598,7 +454,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
signingKey := []byte("secret")
config := JWTConfig{
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != "HS256" {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
@ -621,13 +477,166 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
}
func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
success := c.Get("success").(string)
user := c.Get("user").(string)
return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user))
})
mw, err := JWTConfig{
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return auth, nil
},
SuccessHandler: func(c echo.Context) {
c.Set("success", "yes")
},
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, "yes:valid_token_base64", res.Body.String())
assert.Equal(t, http.StatusTeapot, res.Code)
}
func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
name string
givenContinueOnIgnoredError bool
givenErrorHandler JWTErrorHandlerWithContext
givenTokenLookup string
whenAuthHeaders []string
whenCookies []string
whenParseReturn string
whenParseError error
expectHandlerCalled bool
expect string
expectCode int
}{
{
name: "ok, with valid JWT from auth header",
givenContinueOnIgnoredError: true,
givenErrorHandler: func(c echo.Context, err error) error {
return nil
},
whenAuthHeaders: []string{"Bearer valid_token_base64"},
whenParseReturn: "valid_token",
expectCode: http.StatusTeapot,
expect: "valid_token",
},
{
name: "ok, missing header, callNext and set public_token from error handler",
givenContinueOnIgnoredError: true,
givenErrorHandler: func(c echo.Context, err error) error {
if errors.Is(err, &ValueExtractorError{}) {
panic("must get ErrJWTMissing")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "ok, invalid token, callNext and set public_token from error handler",
givenContinueOnIgnoredError: true,
givenErrorHandler: func(c echo.Context, err error) error {
// this is probably not realistic usecase. on parse error you probably want to return error
if err.Error() != "parser_error" {
panic("must get parser_error")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "nok, invalid token, return error from error handler",
givenContinueOnIgnoredError: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err.Error() != "parser_error" {
panic("must get parser_error")
}
return err
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusInternalServerError,
expect: "{\"message\":\"Internal Server Error\"}\n",
},
{
name: "nok, ContinueOnIgnoredError but return error from error handler",
givenContinueOnIgnoredError: true,
givenErrorHandler: func(c echo.Context, err error) error {
return echo.ErrUnauthorized.WithInternal(err)
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"Unauthorized\"}\n",
},
{
name: "nok, ContinueOnIgnoredError=false",
givenContinueOnIgnoredError: false,
givenErrorHandler: func(c echo.Context, err error) error {
return echo.ErrUnauthorized.WithInternal(err)
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(string)
return c.String(http.StatusTeapot, token)
})
mw, err := JWTConfig{
ContinueOnIgnoredError: tc.givenContinueOnIgnoredError,
TokenLookup: tc.givenTokenLookup,
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return tc.whenParseReturn, tc.whenParseError
},
ErrorHandler: tc.givenErrorHandler,
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
for _, a := range tc.whenAuthHeaders {
req.Header.Add(echo.HeaderAuthorization, a)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expect, res.Body.String())
assert.Equal(t, tc.expectCode, res.Code)
})
}
}
func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e := echo.New()
@ -637,12 +646,12 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
})
e.Use(JWTWithConfig(JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
TokenLookupFuncs: []ValuesExtractor{
func(c echo.Context) ([]string, error) {
return []string{c.Request().Header.Get("X-API-Key")}, nil
},
},
SigningKey: []byte("secret"),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -653,127 +662,3 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
}
func TestJWTConfig_SuccessHandler(t *testing.T) {
var testCases = []struct {
name string
givenToken string
expectCalled bool
expectStatus int
}{
{
name: "ok, success handler is called",
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
expectCalled: true,
expectStatus: http.StatusOK,
},
{
name: "nok, success handler is not called",
givenToken: "x.x.x",
expectCalled: false,
expectStatus: http.StatusUnauthorized,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
})
wasCalled := false
e.Use(JWTWithConfig(JWTConfig{
SuccessHandler: func(c echo.Context) {
wasCalled = true
},
SigningKey: []byte("secret"),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectCalled, wasCalled)
assert.Equal(t, tc.expectStatus, res.Code)
})
}
}
func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
name string
whenContinueOnIgnoredError bool
givenToken string
expectStatus int
expectBody string
}{
{
name: "no error handler is called",
whenContinueOnIgnoredError: true,
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
expectStatus: http.StatusTeapot,
expectBody: "",
},
{
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
whenContinueOnIgnoredError: false,
givenToken: "",
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
expectStatus: http.StatusOK,
expectBody: "",
},
{
name: "error handler is called for missing token",
whenContinueOnIgnoredError: true,
givenToken: "",
expectStatus: http.StatusTeapot,
expectBody: "public-token",
},
{
name: "error handler is called for invalid token",
whenContinueOnIgnoredError: true,
givenToken: "x.x.x",
expectStatus: http.StatusUnauthorized,
expectBody: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
testValue, _ := c.Get("test").(string)
return c.String(http.StatusTeapot, testValue)
})
e.Use(JWTWithConfig(JWTConfig{
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, c echo.Context) error {
if err == ErrJWTMissing {
c.Set("test", "public-token")
return nil
}
return echo.ErrUnauthorized
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenToken != "" {
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectStatus, res.Code)
assert.Equal(t, tc.expectBody, res.Body.String())
})
}
}

View File

@ -2,13 +2,13 @@ package middleware
import (
"errors"
"github.com/labstack/echo/v4"
"fmt"
"github.com/labstack/echo/v5"
"net/http"
)
type (
// KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct {
type KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -28,16 +28,17 @@ type (
// - "header:Authorization,header:X-Api-Key"
KeyLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
// ErrorHandler defines a function which is executed for an invalid key.
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
// It may be used to define a custom error.
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
ErrorHandler KeyAuthErrorHandler
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
@ -49,34 +50,21 @@ type (
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(auth string, c echo.Context) (bool, error)
type KeyAuthValidator func(c echo.Context, key string) (bool, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(err error, c echo.Context) error
)
type KeyAuthErrorHandler func(c echo.Context, err error) error
// ErrKeyMissing denotes an error raised when key value could not be extracted from request
var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
// ErrInvalidKey denotes an error raised when key value is invalid by validator
var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
var (
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
DefaultKeyAuthConfig = KeyAuthConfig{
var DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: DefaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
}
)
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
type ErrKeyAuthMissing struct {
Err error
}
// Error returns errors text
func (e *ErrKeyAuthMissing) Error() string {
return e.Err.Error()
}
// Unwrap unwraps error
func (e *ErrKeyAuthMissing) Unwrap() error {
return e.Err
KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
}
// KeyAuth returns an KeyAuth middleware.
@ -90,27 +78,33 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
return KeyAuthWithConfig(c)
}
// KeyAuthWithConfig returns an KeyAuth middleware with config.
// See `KeyAuth()`.
// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
//
// For first valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
panic("echo: key-auth middleware requires a validator function")
return nil, errors.New("echo key-auth middleware requires a validator function")
}
extractors, err := createExtractors(config.KeyLookup, config.AuthScheme)
extractors, err := createExtractors(config.KeyLookup)
if err != nil {
panic(err)
return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err)
}
if len(extractors) == 0 {
return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -122,59 +116,41 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
var lastExtractorErr error
var lastValidatorErr error
for _, extractor := range extractors {
keys, err := extractor(c)
if err != nil {
lastExtractorErr = err
keys, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
for _, key := range keys {
valid, err := config.Validator(key, c)
valid, err := config.Validator(c, key)
if err != nil {
lastValidatorErr = err
continue
}
if valid {
if !valid {
lastValidatorErr = ErrInvalidKey
continue
}
return next(c)
}
lastValidatorErr = errors.New("invalid key")
}
}
// we are here only when we did not successfully extract and validate any of keys
// prioritize validator errors over extracting errors
err := lastValidatorErr
if err == nil { // prioritize validator errors over extracting errors
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
err = errors.New("missing key in the query string")
} else if lastExtractorErr == errCookieExtractorValueMissing {
err = errors.New("missing key in cookies")
} else if lastExtractorErr == errFormExtractorValueMissing {
err = errors.New("missing key in the form")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
err = errors.New("missing key in request header")
} else if lastExtractorErr == errHeaderExtractorValueInvalid {
err = errors.New("invalid key in the request header")
} else {
if err == nil {
err = lastExtractorErr
}
err = &ErrKeyAuthMissing{Err: err}
}
if config.ErrorHandler != nil {
tmpErr := config.ErrorHandler(err, c)
tmpErr := config.ErrorHandler(c, err)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
}
return tmpErr
}
if lastValidatorErr != nil { // prioritize validator errors over extracting errors
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "Unauthorized",
Internal: lastValidatorErr,
}
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
if lastValidatorErr == nil {
return ErrKeyMissing.WithInternal(err)
}
return echo.ErrUnauthorized.WithInternal(err)
}
}, nil
}

View File

@ -7,11 +7,11 @@ import (
"strings"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func testKeyValidator(key string, c echo.Context) (bool, error) {
func testKeyValidator(c echo.Context, key string) (bool, error) {
switch key {
case "valid-key":
return true, nil
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
},
expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized, internal=invalid key",
expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key",
},
{
name: "nok, defaults, invalid scheme in header",
@ -84,24 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
},
expectHandlerCalled: false,
expectError: "code=400, message=invalid key in the request header",
expectError: "code=401, message=missing key, internal=invalid value in request header",
},
{
name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup from multiple places, query and header",
givenRequest: func(req *http.Request) {
req.URL.RawQuery = "key=invalid-key"
req.Header.Set("API-Key", "valid-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key,header:API-Key"
},
expectHandlerCalled: true,
expectError: "code=401, message=missing key, internal=missing value in request header",
},
{
name: "ok, custom key lookup, header",
@ -121,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
expectError: "code=401, message=missing key, internal=missing value in request header",
},
{
name: "ok, custom key lookup, query",
@ -141,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the query string",
expectError: "code=401, message=missing key, internal=missing value in the query string",
},
{
name: "ok, custom key lookup, form",
@ -166,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the form",
expectError: "code=401, message=missing key, internal=missing value in the form",
},
{
name: "ok, custom key lookup, cookie",
@ -190,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in cookies",
expectError: "code=401, message=missing key, internal=missing value in cookies",
},
{
name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token"
conf.ErrorHandler = func(err error, context echo.Context) error {
conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
}
},
expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=missing key in request header",
expectError: "code=418, message=custom, internal=missing value in request header",
},
{
name: "nok, custom errorHandler, error from validator",
@ -211,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.ErrorHandler = func(err error, context echo.Context) error {
conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
@ -269,108 +258,96 @@ func TestKeyAuthWithConfig(t *testing.T) {
}
}
func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
assert.PanicsWithError(
t,
"extractor source for lookup could not be split into needed parts: a",
func() {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
KeyLookup: "a",
})(handler)
},
)
}
func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
assert.PanicsWithValue(
t,
"echo: key-auth middleware requires a validator function",
func() {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
KeyAuthWithConfig(KeyAuthConfig{
Validator: nil,
})(handler)
},
)
}
func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
func TestKeyAuthWithConfig_errors(t *testing.T) {
var testCases = []struct {
name string
whenContinueOnIgnoredError bool
givenKey string
expectStatus int
expectBody string
whenConfig KeyAuthConfig
expectError string
}{
{
name: "no error handler is called",
whenContinueOnIgnoredError: true,
givenKey: "valid-key",
expectStatus: http.StatusTeapot,
expectBody: "",
name: "ok, no error",
whenConfig: KeyAuthConfig{
Validator: func(c echo.Context, key string) (bool, error) {
return false, nil
},
},
},
{
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
whenContinueOnIgnoredError: false,
givenKey: "",
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
expectStatus: http.StatusOK,
expectBody: "",
name: "ok, missing validator func",
whenConfig: KeyAuthConfig{
Validator: nil,
},
expectError: "echo key-auth middleware requires a validator function",
},
{
name: "error handler is called for missing token",
whenContinueOnIgnoredError: true,
givenKey: "",
expectStatus: http.StatusTeapot,
expectBody: "public-auth",
name: "ok, extractor source can not be split",
whenConfig: KeyAuthConfig{
KeyLookup: "nope",
Validator: func(c echo.Context, key string) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
},
{
name: "error handler is called for invalid token",
whenContinueOnIgnoredError: true,
givenKey: "x.x.x",
expectStatus: http.StatusUnauthorized,
expectBody: "{\"message\":\"Unauthorized\"}\n",
name: "ok, no extractors",
whenConfig: KeyAuthConfig{
KeyLookup: "nope:nope",
Validator: func(c echo.Context, key string) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
testValue, _ := c.Get("test").(string)
return c.String(http.StatusTeapot, testValue)
mw, err := tc.whenConfig.ToMiddleware()
if tc.expectError != "" {
assert.Nil(t, mw)
assert.EqualError(t, err, tc.expectError)
} else {
assert.NotNil(t, mw)
assert.NoError(t, err)
}
})
}
}
e.Use(KeyAuthWithConfig(KeyAuthConfig{
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
KeyAuthWithConfig(KeyAuthConfig{})
})
}
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
handlerCalled := false
var authValue string
handler := func(c echo.Context) error {
handlerCalled = true
authValue = c.Get("auth").(string)
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
ErrorHandler: func(err error, c echo.Context) error {
if _, ok := err.(*ErrKeyAuthMissing); ok {
c.Set("test", "public-auth")
ErrorHandler: func(c echo.Context, err error) error {
// could check error to decide if we can swallow the error
c.Set("auth", "public")
return nil
}
return echo.ErrUnauthorized
},
KeyLookup: "header:X-API-Key",
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
}))
ContinueOnIgnoredError: true,
})(handler)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenKey != "" {
req.Header.Set("X-API-Key", tc.givenKey)
}
res := httptest.NewRecorder()
// no auth header this time
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(res, req)
err := middlewareChain(c)
assert.Equal(t, tc.expectStatus, res.Code)
assert.Equal(t, tc.expectBody, res.Body.String())
})
}
assert.NoError(t, err)
assert.True(t, handlerCalled)
assert.Equal(t, "public", authValue)
}

View File

@ -3,20 +3,19 @@ package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/color"
"github.com/labstack/echo/v5"
"github.com/valyala/fasttemplate"
)
type (
// LoggerConfig defines the config for Logger middleware.
LoggerConfig struct {
type LoggerConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -49,42 +48,41 @@ type (
// Example "${remote_ip} ${status}"
//
// Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"`
Format string
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"`
CustomTimeFormat string
// Output is a writer where logs in JSON format are written.
// Optional. Default value os.Stdout.
// Optional. Default destination `echo.Logger.Infof()`
Output io.Writer
template *fasttemplate.Template
colorer *color.Color
pool *sync.Pool
}
)
var (
// DefaultLoggerConfig is the default Logger middleware config.
DefaultLoggerConfig = LoggerConfig{
var DefaultLoggerConfig = LoggerConfig{
Skipper: DefaultSkipper,
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
Format: `{"time":"${time_rfc3339_nano}","level":"INFO","id":"${id}","remote_ip":"${remote_ip}",` +
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
CustomTimeFormat: "2006-01-02 15:04:05.00000",
colorer: color.New(),
}
)
// Logger returns a middleware that logs HTTP requests.
func Logger() echo.MiddlewareFunc {
return LoggerWithConfig(DefaultLoggerConfig)
}
// LoggerWithConfig returns a Logger middleware with config.
// See: `Logger()`.
// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration
func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper
@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
if config.Format == "" {
config.Format = DefaultLoggerConfig.Format
}
if config.Output == nil {
config.Output = DefaultLoggerConfig.Output
}
config.template = fasttemplate.New(config.Format, "${", "}")
config.colorer = color.New()
config.colorer.SetOutput(config.Output)
config.pool = &sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
start := time.Now()
if err = next(c); err != nil {
c.Error(err)
}
err := next(c)
stop := time.Now()
buf := config.pool.Get().(*bytes.Buffer)
buf.Reset()
defer config.pool.Put(buf)
if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
_, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
switch tag {
case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "user_agent":
return buf.WriteString(req.UserAgent())
case "status":
n := res.Status
s := config.colorer.Green(n)
switch {
case n >= 500:
s = config.colorer.Red(n)
case n >= 400:
s = config.colorer.Yellow(n)
case n >= 300:
s = config.colorer.Cyan(n)
status := res.Status
if err != nil {
if httpErr, ok := err.(*echo.HTTPError); ok {
status = httpErr.Code
}
return buf.WriteString(s)
}
return buf.WriteString(strconv.Itoa(status))
case "error":
if err != nil {
// Error may contain invalid JSON e.g. `"`
@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case strings.HasPrefix(tag, "form:"):
return buf.Write([]byte(c.FormValue(tag[5:])))
case strings.HasPrefix(tag, "cookie:"):
cookie, err := c.Cookie(tag[7:])
if err == nil {
cookie, cookieErr := c.Cookie(tag[7:])
if cookieErr == nil {
return buf.Write([]byte(cookie.Value))
}
}
}
return 0, nil
}); err != nil {
return
})
if tmplErr != nil {
if err != nil {
return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err)
}
return fmt.Errorf("failed to create log from template: %w", tmplErr)
}
if config.Output == nil {
_, err = c.Logger().Output().Write(buf.Bytes())
return
if config.Output != nil {
if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil {
return lErr
}
_, err = config.Output.Write(buf.Bytes())
return
} else {
if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil {
return lErr
}
}
return err
}
}, nil
}

View File

@ -12,7 +12,7 @@ import (
"time"
"unsafe"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
e.Logger = &testLogger{output: buf}
ip := "127.0.0.1"
h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test")

View File

@ -3,12 +3,11 @@ package middleware
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// MethodOverrideConfig defines the config for MethodOverride middleware.
MethodOverrideConfig struct {
type MethodOverrideConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -18,16 +17,13 @@ type (
}
// MethodOverrideGetter is a function that gets overridden method from the request
MethodOverrideGetter func(echo.Context) string
)
type MethodOverrideGetter func(echo.Context) string
var (
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{
var DefaultMethodOverrideConfig = MethodOverrideConfig{
Skipper: DefaultSkipper,
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
)
// MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and
@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
// MethodOverrideWithConfig returns a MethodOverride middleware with config.
// See: `MethodOverride()`.
// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper
@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from

View File

@ -6,7 +6,7 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) {
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
c := e.NewContext(req, rec)
m(h)(c)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_formParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with form parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
rec = httptest.NewRecorder()
m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec)
m(h)(c)
c := e.NewContext(req, rec)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_queryParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with query parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
m(h)(c)
m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_ignoreGet(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Ignore `GET`
req = httptest.NewRequest(http.MethodGet, "/", nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodGet, req.Method)
}

View File

@ -6,17 +6,14 @@ import (
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware.
Skipper func(c echo.Context) bool
// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
type Skipper func(c echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(c echo.Context)
)
type BeforeFunc func(c echo.Context)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1)
@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
func DefaultSkipper(echo.Context) bool {
return false
}
func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

View File

@ -2,6 +2,7 @@ package middleware
import (
"context"
"errors"
"fmt"
"io"
"math/rand"
@ -15,14 +16,13 @@ import (
"sync/atomic"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
// TODO: Handle TLS proxy
type (
// ProxyConfig defines the config for Proxy middleware.
ProxyConfig struct {
type ProxyConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -59,46 +59,43 @@ type (
}
// ProxyTarget defines the upstream target.
ProxyTarget struct {
type ProxyTarget struct {
Name string
URL *url.URL
Meta echo.Map
}
// ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface {
type ProxyBalancer interface {
AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget
}
commonBalancer struct {
type commonBalancer struct {
targets []*ProxyTarget
mutex sync.RWMutex
}
// RandomBalancer implements a random load balancing technique.
randomBalancer struct {
type randomBalancer struct {
*commonBalancer
random *rand.Rand
}
// RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct {
type roundRobinBalancer struct {
*commonBalancer
i uint32
}
)
var (
// DefaultProxyConfig is the default Proxy middleware config.
DefaultProxyConfig = ProxyConfig{
var DefaultProxyConfig = ProxyConfig{
Skipper: DefaultSkipper,
ContextKey: "target",
}
)
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
return ProxyWithConfig(c)
}
// ProxyWithConfig returns a Proxy middleware with config.
// See: `Proxy()`
// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
//
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper
}
if config.ContextKey == "" {
config.ContextKey = DefaultProxyConfig.ContextKey
}
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
return nil, errors.New("echo proxy middleware requires balancer")
}
if config.Rewrite != nil {
@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
proxyRaw(c, tgt).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
}
if e, ok := c.Get("_error").(error); ok {
err = e
@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
return
}
}
}, nil
}
// StatusCodeContextCanceled is a custom HTTP status code for situations
@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
desc := tgt.URL.String()

View File

@ -14,7 +14,7 @@ import (
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -55,7 +55,7 @@ func TestProxy(t *testing.T) {
// Random
e := echo.New()
e.Use(Proxy(rb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
@ -77,7 +77,7 @@ func TestProxy(t *testing.T) {
// Round-robin
rrb := NewRoundRobinBalancer(targets)
e = echo.New()
e.Use(Proxy(rrb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
@ -113,15 +113,20 @@ func TestProxy(t *testing.T) {
return nil
}
}
rrb1 := NewRoundRobinBalancer(targets)
e = echo.New()
e.Use(contextObserver)
e.Use(Proxy(rrb1))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
}
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
assert.Panics(t, func() {
ProxyWithConfig(ProxyConfig{Balancer: nil})
})
}
func TestProxyRealIPHeader(t *testing.T) {
// Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) {
url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
e := echo.New()
e.Use(Proxy(rrb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) {
// Random
e := echo.New()
e.Use(Proxy(rb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
rb := NewRandomBalancer(nil)
assert.True(t, rb.AddTarget(target))
e := echo.New()
e.Use(Proxy(rb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(req.Context())

View File

@ -1,25 +1,22 @@
package middleware
import (
"errors"
"net/http"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"golang.org/x/time/rate"
)
type (
// RateLimiterStore is the interface to be implemented by custom stores.
RateLimiterStore interface {
// Stores for the rate limiter have to implement the Allow method
type RateLimiterStore interface {
Allow(identifier string) (bool, error)
}
)
type (
// RateLimiterConfig defines the configuration for the rate limiter
RateLimiterConfig struct {
type RateLimiterConfig struct {
Skipper Skipper
BeforeFunc BeforeFunc
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
@ -31,17 +28,15 @@ type (
// DenyHandler provides a handler to be called when RateLimiter denies access
DenyHandler func(context echo.Context, identifier string, err error) error
}
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors
var (
// Extractor is used to extract data from echo.Context
type Extractor func(context echo.Context) (string, error)
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
)
var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{
@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware
}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper
}
@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
}
if config.Store == nil {
panic("Store configuration must be provided")
return nil, errors.New("echo rate limiter store configuration must be provided")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@ -137,36 +137,32 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
identifier, err := config.IdentifierExtractor(c)
if err != nil {
c.Error(config.ErrorHandler(c, err))
return nil
return config.ErrorHandler(c, err)
}
if allow, err := config.Store.Allow(identifier); !allow {
c.Error(config.DenyHandler(c, identifier, err))
return nil
if allow, allowErr := config.Store.Allow(identifier); !allow {
return config.DenyHandler(c, identifier, allowErr)
}
return next(c)
}
}
}, nil
}
type (
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
RateLimiterMemoryStore struct {
type RateLimiterMemoryStore struct {
visitors map[string]*Visitor
mutex sync.Mutex
rate rate.Limit //for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit
burst int
expiresIn time.Duration
lastCleanup time.Time
}
// Visitor signifies a unique user's limiter details
Visitor struct {
type Visitor struct {
*rate.Limiter
lastSeen time.Time
}
)
/*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with

View File

@ -10,8 +10,7 @@ import (
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
mw := RateLimiter(inMemoryStore)
mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
testCases := []struct {
id string
code int
expectErr string
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
_ = mw(handler)(c)
assert.Equal(t, tc.code, rec.Code)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
func TestRateLimiter_panicBehaviour(t *testing.T) {
func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
assert.Panics(t, func() {
RateLimiter(nil)
RateLimiterWithConfig(RateLimiterConfig{})
})
assert.NotPanics(t, func() {
RateLimiter(inMemoryStore)
RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
})
}
@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) {
return c.String(http.StatusOK, "test")
}
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) {
return ctx.JSON(http.StatusBadRequest, nil)
},
Store: inMemoryStore,
})
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) {
c := e.NewContext(req, rec)
_ = mw(handler)(c)
err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, tc.code, rec.Code)
}
}
@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return c.String(http.StatusOK, "test")
}
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return id, nil
},
Store: inMemoryStore,
})
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
code int
expectErr string
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusTooManyRequests},
{"", http.StatusForbidden},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
c := e.NewContext(req, rec)
_ = mw(handler)(c)
assert.Equal(t, tc.code, rec.Code)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
return c.String(http.StatusOK, "test")
}
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
Store: inMemoryStore,
})
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
code int
expectErr string
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
c := e.NewContext(req, rec)
_ = mw(handler)(c)
assert.Equal(t, tc.code, rec.Code)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
}
@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool {
return true
},
@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
})
}.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c)
err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, false, beforeFuncRan)
}
@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool {
return false
},
@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
})
}.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c)
@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
BeforeFunc: func(c echo.Context) {
beforeRan = true
},
@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
})
}.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c)
err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, true, beforeRan)
}
@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
func generateAddressList(count int) []string {
addrs := make([]string, count)
for i := 0; i < count; i++ {
addrs[i] = random.String(15)
addrs[i] = randomString(15)
}
return addrs
}

View File

@ -5,53 +5,35 @@ import (
"net/http"
"runtime"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/labstack/echo/v5"
)
type (
// LogErrorFunc defines a function for custom logging in the middleware.
LogErrorFunc func(c echo.Context, err error, stack []byte) error
// RecoverConfig defines the config for Recover middleware.
RecoverConfig struct {
type RecoverConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Size of the stack to be printed.
// Optional. Default value 4KB.
StackSize int `yaml:"stack_size"`
StackSize int
// DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine.
// Optional. Default value false.
DisableStackAll bool `yaml:"disable_stack_all"`
DisableStackAll bool
// DisablePrintStack disables printing stack trace.
// Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"`
// LogLevel is log level to printing stack trace.
// Optional. Default value 0 (Print).
LogLevel log.Lvl
// LogErrorFunc defines a function for custom logging in the middleware.
// If it's set you don't need to provide LogLevel for config.
LogErrorFunc LogErrorFunc
DisablePrintStack bool
}
)
var (
// DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{
var DefaultRecoverConfig = RecoverConfig{
Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB
DisableStackAll: false,
DisablePrintStack: false,
LogLevel: 0,
LogErrorFunc: nil,
}
)
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
@ -59,9 +41,13 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig)
}
// RecoverWithConfig returns a Recover middleware with config.
// See: `Recover()`.
// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper
@ -71,7 +57,7 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
@ -81,42 +67,19 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
if r == http.ErrAbortHandler {
panic(r)
}
err, ok := r.(error)
tmpErr, ok := r.(error)
if !ok {
err = fmt.Errorf("%v", r)
tmpErr = fmt.Errorf("%v", r)
}
var stack []byte
var length int
if !config.DisablePrintStack {
stack = make([]byte, config.StackSize)
length = runtime.Stack(stack, !config.DisableStackAll)
stack = stack[:length]
stack := make([]byte, config.StackSize)
length := runtime.Stack(stack, !config.DisableStackAll)
tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length])
}
if config.LogErrorFunc != nil {
err = config.LogErrorFunc(c, err, stack)
} else if !config.DisablePrintStack {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
switch config.LogLevel {
case log.DEBUG:
c.Logger().Debug(msg)
case log.INFO:
c.Logger().Info(msg)
case log.WARN:
c.Logger().Warn(msg)
case log.ERROR:
c.Logger().Error(msg)
case log.OFF:
// None.
default:
c.Logger().Print(msg)
}
}
c.Error(err)
err = tmpErr
}
}()
return next(c)
}
}
}, nil
}

View File

@ -2,36 +2,57 @@ package middleware
import (
"bytes"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestRecover(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
e.Logger = &testLogger{output: buf}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
h := Recover()(func(c echo.Context) error {
panic("test")
}))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, buf.String(), "PANIC RECOVER")
})
err := h(c)
assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
assert.Contains(t, buf.String(), "") // nothing is logged
}
func TestRecover_skipper(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := RecoverConfig{
Skipper: func(c echo.Context) bool {
return true
},
}
h := RecoverWithConfig(config)(func(c echo.Context) error {
panic("testPANIC")
})
var err error
assert.Panics(t, func() {
err = h(c)
})
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}
func TestRecoverErrAbortHandler(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@ -51,115 +72,66 @@ func TestRecoverErrAbortHandler(t *testing.T) {
}
}()
h(c)
hErr := h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.NotContains(t, buf.String(), "PANIC RECOVER")
assert.NotContains(t, hErr.Error(), "PANIC RECOVER")
}
func TestRecoverWithConfig_LogLevel(t *testing.T) {
tests := []struct {
logLevel log.Lvl
levelName string
}{{
logLevel: log.DEBUG,
levelName: "DEBUG",
}, {
logLevel: log.INFO,
levelName: "INFO",
}, {
logLevel: log.WARN,
levelName: "WARN",
}, {
logLevel: log.ERROR,
levelName: "ERROR",
}, {
logLevel: log.OFF,
levelName: "OFF",
}}
func TestRecoverWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenNoPanic bool
whenConfig RecoverConfig
expectErrContain string
expectErr string
}{
{
name: "ok, default config",
whenConfig: DefaultRecoverConfig,
expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
},
{
name: "ok, no panic",
givenNoPanic: true,
whenConfig: DefaultRecoverConfig,
expectErrContain: "",
},
{
name: "ok, DisablePrintStack",
whenConfig: RecoverConfig{
DisablePrintStack: true,
},
expectErr: "testPANIC",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.levelName, func(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Logger.SetLevel(log.DEBUG)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := DefaultRecoverConfig
config.LogLevel = tt.logLevel
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic("test")
}))
config := tc.whenConfig
h := RecoverWithConfig(config)(func(c echo.Context) error {
if tc.givenNoPanic {
return nil
}
panic("testPANIC")
})
h(c)
err := h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
output := buf.String()
if tt.logLevel == log.OFF {
assert.Empty(t, output)
if tc.expectErrContain != "" {
assert.Contains(t, err.Error(), tc.expectErrContain)
} else if tc.expectErr != "" {
assert.Contains(t, err.Error(), tc.expectErr)
} else {
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
})
}
}
func TestRecoverWithConfig_LogErrorFunc(t *testing.T) {
e := echo.New()
e.Logger.SetLevel(log.DEBUG)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
testError := errors.New("test")
config := DefaultRecoverConfig
config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack)
if errors.Is(err, testError) {
c.Logger().Debug(msg)
} else {
c.Logger().Error(msg)
}
return err
}
t.Run("first branch case for LogErrorFunc", func(t *testing.T) {
buf.Reset()
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic(testError)
}))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
output := buf.String()
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, `"level":"DEBUG"`)
})
t.Run("else branch case for LogErrorFunc", func(t *testing.T) {
buf.Reset()
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic("other")
}))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
output := buf.String()
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, `"level":"ERROR"`)
})
}

View File

@ -1,10 +1,11 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
// RedirectConfig defines the config for Redirect middleware.
@ -14,7 +15,9 @@ type RedirectConfig struct {
// Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently.
Code int `yaml:"code"`
Code int
redirect redirectLogic
}
// redirectLogic represents a function that given a scheme, host and uri
@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www."
// DefaultRedirectConfig is the default Redirect middleware config.
var DefaultRedirectConfig = RedirectConfig{
Skipper: DefaultSkipper,
Code: http.StatusMovedPermanently,
}
// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
// RedirectWWWConfig is the WWW Redirect middleware config.
var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
// RedirectNonWWWConfig is the non WWW Redirect middleware config.
var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
// HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc {
return HTTPSRedirectWithConfig(DefaultRedirectConfig)
return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
}
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSRedirect()`.
// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
})
config.redirect = redirectHTTPS
return toMiddlewareOrPanic(config)
}
// HTTPSWWWRedirect redirects http requests to https www.
@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc {
return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
}
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSWWWRedirect()`.
// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
})
config.redirect = redirectHTTPSWWW
return toMiddlewareOrPanic(config)
}
// HTTPSNonWWWRedirect redirects http requests to https non www.
@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
}
// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSNonWWWRedirect()`.
// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
})
config.redirect = redirectNonHTTPSWWW
return toMiddlewareOrPanic(config)
}
// WWWRedirect redirects non www requests to www.
@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc {
return WWWRedirectWithConfig(DefaultRedirectConfig)
return WWWRedirectWithConfig(RedirectWWWConfig)
}
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `WWWRedirect()`.
// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
})
config.redirect = redirectWWW
return toMiddlewareOrPanic(config)
}
// NonWWWRedirect redirects www requests to non www.
@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(DefaultRedirectConfig)
return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
}
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `NonWWWRedirect()`.
// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
})
config.redirect = redirectNonWWW
return toMiddlewareOrPanic(config)
}
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRedirectConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Code == 0 {
config.Code = DefaultRedirectConfig.Code
config.Code = http.StatusMovedPermanently
}
if config.redirect == nil {
return nil, errors.New("redirectConfig is missing redirect function")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
req, scheme := c.Request(), c.Scheme()
host := req.Host
if ok, url := cb(scheme, host, req.RequestURI); ok {
if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url)
}
return next(c)
}
}, nil
}
var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
}
var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
}
var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
}
var redirectWWW = func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
}
var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
}

View File

@ -5,7 +5,7 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)

View File

@ -1,13 +1,11 @@
package middleware
import (
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/labstack/echo/v5"
)
type (
// RequestIDConfig defines the config for RequestID middleware.
RequestIDConfig struct {
type RequestIDConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -16,35 +14,29 @@ type (
Generator func() string
// RequestIDHandler defines a function which is executed for a request id.
RequestIDHandler func(echo.Context, string)
RequestIDHandler func(c echo.Context, requestID string)
// TargetHeader defines what header to look for to populate the id
TargetHeader string
}
)
var (
// DefaultRequestIDConfig is the default RequestID middleware config.
DefaultRequestIDConfig = RequestIDConfig{
Skipper: DefaultSkipper,
Generator: generator,
TargetHeader: echo.HeaderXRequestID,
}
)
// RequestID returns a X-Request-ID middleware.
func RequestID() echo.MiddlewareFunc {
return RequestIDWithConfig(DefaultRequestIDConfig)
return RequestIDWithConfig(RequestIDConfig{})
}
// RequestIDWithConfig returns a X-Request-ID middleware with config.
// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRequestIDConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Generator == nil {
config.Generator = generator
config.Generator = createRandomStringGenerator(32)
}
if config.TargetHeader == "" {
config.TargetHeader = echo.HeaderXRequestID
@ -69,9 +61,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
return next(c)
}
}
}
func generator() string {
return random.String(32)
}, nil
}

View File

@ -5,7 +5,7 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{})
rid := RequestID()
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
}
func TestMustRequestIDWithConfig_skipper(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
generatorCalled := false
e.Use(RequestIDWithConfig(RequestIDConfig{
Skipper: func(c echo.Context) bool {
return true
},
Generator: func() string {
generatorCalled = true
return "customGenerator"
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "test", res.Body.String())
assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
assert.False(t, generatorCalled)
}
func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
called := false
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
RequestIDHandler: func(c echo.Context, s string) {
called = true
},
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, called)
}
func TestRequestIDWithConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid, err := RequestIDConfig{}.ToMiddleware()
assert.NoError(t, err)
h := rid(handler)
h(c)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
// Custom generator and handler
customID := "customGenerator"
calledHandler := false
// Custom generator
rid = RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return customID },
RequestIDHandler: func(_ echo.Context, id string) {
calledHandler = true
assert.Equal(t, customID, id)
},
Generator: func() string { return "customGenerator" },
})
h = rid(handler)
h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, calledHandler)
}
func TestRequestID_IDNotAltered(t *testing.T) {

View File

@ -2,7 +2,7 @@ package middleware
import (
"errors"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"net/http"
"time"
)
@ -24,6 +24,7 @@ import (
// LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info().
// Date("request_start", v.StartTime).
// Str("URI", v.URI).
// Int("status", v.Status).
// Msg("request")
@ -39,6 +40,7 @@ import (
// LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info("request",
// zap.Time("request_start", v.StartTime),
// zap.String("URI", v.URI),
// zap.Int("status", v.Status),
// )
@ -54,6 +56,7 @@ import (
// LogStatus: true,
// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error {
// log.WithFields(logrus.Fields{
// "request_start": values.StartTime,
// "URI": values.URI,
// "status": values.Status,
// }).Info("request")
@ -158,15 +161,15 @@ type RequestLoggerValues struct {
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
ResponseSize int64
// Headers are list of headers from request. Note: request can contain more than one header with same value so slice
// of values is been logger for each given header.
// of values is what will be returned/logged for each given header.
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
Headers map[string][]string
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
// with same name so slice of values is been logger for each given query param name.
// with same name so slice of values is what will be returned/logged for each given query param name.
QueryParams map[string][]string
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
// same name so slice of values is been logger for each given form value name.
// same name so slice of values is what will be returned/logged for each given form value name.
FormValues map[string][]string
}

View File

@ -1,7 +1,7 @@
package middleware
import (
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) {
req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c := e.NewContext(req, rec).(echo.ServableContext)
c.SetPath("/test*")

View File

@ -1,14 +1,14 @@
package middleware
import (
"errors"
"regexp"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// RewriteConfig defines the config for Rewrite middleware.
RewriteConfig struct {
type RewriteConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -20,43 +20,39 @@ type (
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
// Required.
Rules map[string]string `yaml:"rules"`
Rules map[string]string
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
RegexRules map[*regexp.Regexp]string
}
)
var (
// DefaultRewriteConfig is the default Rewrite middleware config.
DefaultRewriteConfig = RewriteConfig{
Skipper: DefaultSkipper,
}
)
// Rewrite returns a Rewrite middleware.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
c := DefaultRewriteConfig
c := RewriteConfig{}
c.Rules = rules
return RewriteWithConfig(c)
}
// RewriteWithConfig returns a Rewrite middleware with config.
// See: `Rewrite()`.
// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults
if config.Rules == nil && config.RegexRules == nil {
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Rules == nil && config.RegexRules == nil {
return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
}
if config.RegexRules == nil {
@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}

View File

@ -8,7 +8,7 @@ import (
"regexp"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) {
},
}))
e.GET("/public/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*"))
return c.String(http.StatusOK, c.PathParam("*"))
})
e.GET("/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*"))
return c.String(http.StatusOK, c.PathParam("*"))
})
var testCases = []struct {
@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) {
}
}
func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
assert.Panics(t, func() {
RewriteWithConfig(RewriteConfig{})
})
}
func TestMustRewriteWithConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
givenSkipper func(c echo.Context) bool
whenURL string
expectURL string
expectStatus int
}{
{
name: "not skipped",
whenURL: "/old",
expectURL: "/new",
expectStatus: http.StatusOK,
},
{
name: "skipped",
givenSkipper: func(c echo.Context) bool {
return true
},
whenURL: "/old",
expectURL: "/old",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Pre(RewriteWithConfig(
RewriteConfig{
Skipper: tc.givenSkipper,
Rules: map[string]string{"/old": "/new"}},
))
e.GET("/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
assert.Equal(t, tc.expectStatus, rec.Code)
})
}
}
// Issue #1086
func TestEchoRewritePreMiddleware(t *testing.T) {
e := echo.New()
r := e.Router()
// Rewrite old url to new one
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(Rewrite(map[string]string{
"/old": "/new",
},
))
e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{"/old": "/new"}}),
)
// Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
e.Add(http.MethodGet, "/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
// Issue #1143
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
e := echo.New()
r := e.Router()
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{
@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
},
}))
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
return c.String(http.StatusOK, "hosts")
})
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
return c.String(http.StatusOK, "eng")
})

View File

@ -3,24 +3,23 @@ package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// SecureConfig defines the config for Secure middleware.
SecureConfig struct {
type SecureConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block".
XSSProtection string `yaml:"xss_protection"`
XSSProtection string
// ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff".
ContentTypeNosniff string `yaml:"content_type_nosniff"`
ContentTypeNosniff string
// XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a <frame>, <iframe> or <object> .
@ -32,58 +31,55 @@ type (
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `yaml:"x_frame_options"`
XFrameOptions string
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
// long (in seconds) browsers should remember that this site is only to
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping
// man-in-the-middle (MITM) attacks.
// Optional. Default value 0.
HSTSMaxAge int `yaml:"hsts_max_age"`
HSTSMaxAge int
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
// header, excluding all subdomains from security policy. It has no effect
// unless HSTSMaxAge is set to a non-zero value.
// Optional. Default value false.
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"`
HSTSExcludeSubdomains bool
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing
// security against cross-site scripting (XSS), clickjacking and other code
// injection attacks resulting from execution of malicious content in the
// trusted web page context.
// Optional. Default value "".
ContentSecurityPolicy string `yaml:"content_security_policy"`
ContentSecurityPolicy string
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
// of the `Content-Security-Policy` header. This allows iterative updates of the
// content security policy by only reporting the violations that would
// have occurred instead of blocking the resource.
// Optional. Default value false.
CSPReportOnly bool `yaml:"csp_report_only"`
CSPReportOnly bool
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
// header, which enables the domain to be included in the HSTS preload list
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
// Optional. Default value false.
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"`
HSTSPreloadEnabled bool
// ReferrerPolicy sets the `Referrer-Policy` header providing security against
// leaking potentially sensitive request paths to third parties.
// Optional. Default value "".
ReferrerPolicy string `yaml:"referrer_policy"`
ReferrerPolicy string
}
)
var (
// DefaultSecureConfig is the default Secure middleware config.
DefaultSecureConfig = SecureConfig{
var DefaultSecureConfig = SecureConfig{
Skipper: DefaultSkipper,
XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
HSTSPreloadEnabled: false,
}
)
// Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack,
@ -93,9 +89,13 @@ func Secure() echo.MiddlewareFunc {
return SecureWithConfig(DefaultSecureConfig)
}
// SecureWithConfig returns a Secure middleware with config.
// See: `Secure()`.
// SecureWithConfig returns a Secure middleware with config or panics on invalid configuration.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts SecureConfig to middleware or returns an error for invalid configuration
func (config SecureConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultSecureConfig.Skipper
@ -141,5 +141,5 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}

View File

@ -5,7 +5,7 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -19,26 +19,40 @@ func TestSecure(t *testing.T) {
}
// Default
Secure()(h)(c)
err := Secure()(h)(c)
assert.NoError(t, err)
assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy))
}
// Custom
func TestSecureWithConfig(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
mw, err := SecureConfig{
XSSProtection: "",
ContentTypeNosniff: "",
XFrameOptions: "",
HSTSMaxAge: 3600,
ContentSecurityPolicy: "default-src 'self'",
ReferrerPolicy: "origin",
})(h)(c)
}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -47,11 +61,21 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_CSPReportOnly(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := SecureWithConfig(SecureConfig{
XSSProtection: "",
ContentTypeNosniff: "",
XFrameOptions: "",
@ -60,6 +84,8 @@ func TestSecure(t *testing.T) {
CSPReportOnly: true,
ReferrerPolicy: "origin",
})(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -67,25 +93,51 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_HSTSPreloadEnabled(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600,
HSTSPreloadEnabled: true,
})(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
}
func TestSecureWithConfig_HSTSExcludeSubdomains(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled and subdomains excluded
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600,
HSTSPreloadEnabled: true,
HSTSExcludeSubdomains: true,
})(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
}

View File

@ -1,44 +1,45 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
)
type (
// TrailingSlashConfig defines the config for TrailingSlash middleware.
TrailingSlashConfig struct {
// AddTrailingSlashConfig is the middleware config for adding trailing slash to the request.
type AddTrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int `yaml:"redirect_code"`
// Valid status codes: [300...308]
RedirectCode int
}
)
var (
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
DefaultTrailingSlashConfig = TrailingSlashConfig{
Skipper: DefaultSkipper,
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`.
//
// Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
return AddTrailingSlashWithConfig(AddTrailingSlashConfig{})
}
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config.
// See `AddTrailingSlash()`.
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration.
func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts AddTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config AddTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for add trailing slash middleware")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -69,7 +70,17 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
}
return next(c)
}
}, nil
}
// RemoveTrailingSlashConfig is the middleware config for removing trailing slash from the request.
type RemoveTrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int
}
// RemoveTrailingSlash returns a root level (before router) middleware which removes
@ -77,15 +88,22 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
//
// Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
return RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{})
}
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
// See `RemoveTrailingSlash()`.
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config or panics on invalid configuration.
func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RemoveTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config RemoveTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for remove trailing slash middleware")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -117,7 +135,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
}
return next(c)
}
}
}, nil
}
func sanitizeURI(uri string) string {

View File

@ -5,7 +5,7 @@ import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@ -67,7 +67,7 @@ func TestAddTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()
mw := AddTrailingSlashWithConfig(TrailingSlashConfig{
mw := AddTrailingSlashWithConfig(AddTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})
h := mw(func(c echo.Context) error {
@ -203,7 +203,7 @@ func TestRemoveTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()
mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{
mw := RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})
h := mw(func(c echo.Context) error {

View File

@ -1,55 +1,65 @@
package middleware
import (
"errors"
"fmt"
"html/template"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
"github.com/labstack/echo/v5"
)
type (
// StaticConfig defines the config for Static middleware.
StaticConfig struct {
type StaticConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Root directory from where the static content is served.
// Root directory from where the static content is served (relative to given Filesystem).
// `Root: "."` means root folder from Filesystem.
// Required.
Root string `yaml:"root"`
Root string
// Filesystem provides access to the static content.
// Optional. Defaults to echo.Filesystem (serves files from `.` folder where executable is started)
Filesystem fs.FS
// Index file for serving a directory.
// Optional. Default value "index.html".
Index string `yaml:"index"`
Index string
// Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing.
// Optional. Default value false.
HTML5 bool `yaml:"html5"`
HTML5 bool
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
Browse bool
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
IgnoreBase bool
// Filesystem provides access to the static content.
// Optional. Defaults to http.Dir(config.Root)
Filesystem http.FileSystem `yaml:"-"`
// DisablePathUnescaping disables path parameter (param: *) unescaping. This is useful when router is set to unescape
// all parameter and doing it again in this middleware would corrupt filename that is requested.
DisablePathUnescaping bool
// DirectoryListTemplate is template to list directory contents
// Optional. Default to `directoryListHTMLTemplate` constant below.
DirectoryListTemplate string
}
)
const html = `
const directoryListHTMLTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
@ -121,25 +131,26 @@ const html = `
</html>
`
var (
// DefaultStaticConfig is the default Static middleware config.
DefaultStaticConfig = StaticConfig{
var DefaultStaticConfig = StaticConfig{
Skipper: DefaultSkipper,
Index: "index.html",
}
)
// Static returns a Static middleware to serves static content from the provided
// root directory.
// Static returns a Static middleware to serves static content from the provided root directory.
func Static(root string) echo.MiddlewareFunc {
c := DefaultStaticConfig
c.Root = root
return StaticWithConfig(c)
}
// StaticWithConfig returns a Static middleware with config.
// See `Static()`.
// StaticWithConfig returns a Static middleware to serves static content or panics on invalid configuration.
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts StaticConfig to middleware or returns an error for invalid configuration
func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Root == "" {
config.Root = "." // For security we want to restrict to CWD.
@ -150,30 +161,32 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if config.Index == "" {
config.Index = DefaultStaticConfig.Index
}
if config.Filesystem == nil {
config.Filesystem = http.Dir(config.Root)
config.Root = "."
if config.DirectoryListTemplate == "" {
config.DirectoryListTemplate = directoryListHTMLTemplate
}
// Index template
t, err := template.New("index").Parse(html)
dirListTemplate, err := template.New("index").Parse(config.DirectoryListTemplate)
if err != nil {
panic(fmt.Sprintf("echo: %v", err))
return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", err)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
p := c.Request().URL.Path
pathUnescape := true
if strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
p = c.Param("*")
p = c.PathParam("*")
pathUnescape = !config.DisablePathUnescaping // because router could already do PathUnescape
}
if pathUnescape {
p, err = url.PathUnescape(p)
if err != nil {
return
return err
}
}
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
@ -186,22 +199,29 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
}
}
file, err := openFile(config.Filesystem, name)
currentFS := config.Filesystem
if currentFS == nil {
currentFS = c.Echo().Filesystem
}
file, err := openFile(currentFS, name)
if err != nil {
if !os.IsNotExist(err) {
return err
}
if err = next(c); err == nil {
return err
// when file does not exist let handler to handle that request. if it succeeds then we are done
err = next(c)
if err == nil {
return nil
}
he, ok := err.(*echo.HTTPError)
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) {
return err
}
file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index))
// is case HTML5 mode is enabled + echo 404 we serve index to the client
file, err = openFile(currentFS, filepath.Join(config.Root, config.Index))
if err != nil {
return err
}
@ -215,10 +235,10 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
}
if info.IsDir() {
index, err := openFile(config.Filesystem, filepath.Join(name, config.Index))
index, err := openFile(currentFS, filepath.Join(name, config.Index))
if err != nil {
if config.Browse {
return listDir(t, name, file, c.Response())
return listDir(dirListTemplate, name, currentFS, file, c.Response())
}
if os.IsNotExist(err) {
@ -238,25 +258,24 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return serveFile(c, file, info)
}
}
}, nil
}
func openFile(fs http.FileSystem, name string) (http.File, error) {
func openFile(fs fs.FS, name string) (fs.File, error) {
pathWithSlashes := filepath.ToSlash(name)
return fs.Open(pathWithSlashes)
}
func serveFile(c echo.Context, file http.File, info os.FileInfo) error {
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file)
func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
ff, ok := file.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), ff)
return nil
}
func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) {
files, err := dir.Readdir(-1)
if err != nil {
return
}
func listDir(t *template.Template, name string, filesystem fs.FS, dir fs.File, res *echo.Response) error {
// Create directory index
res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8)
data := struct {
@ -265,12 +284,60 @@ func listDir(t *template.Template, name string, dir http.File, res *echo.Respons
}{
Name: name,
}
for _, f := range files {
err := fs.WalkDir(filesystem, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
info, infoErr := d.Info()
if infoErr != nil {
return fmt.Errorf("static middleware list dir error when getting file info: %w", err)
}
data.Files = append(data.Files, struct {
Name string
Dir bool
Size string
}{f.Name(), f.IsDir(), bytes.Format(f.Size())})
}{d.Name(), d.IsDir(), format(info.Size())})
return nil
})
if err != nil {
return err
}
return t.Execute(res, data)
}
// format formats bytes integer to human readable string.
// For example, 31323 bytes will return 30.59KB.
func format(b int64) string {
multiple := ""
value := float64(b)
switch {
case b >= EB:
value /= float64(EB)
multiple = "EB"
case b >= PB:
value /= float64(PB)
multiple = "PB"
case b >= TB:
value /= float64(TB)
multiple = "TB"
case b >= GB:
value /= float64(GB)
multiple = "GB"
case b >= MB:
value /= float64(MB)
multiple = "MB"
case b >= KB:
value /= float64(KB)
multiple = "KB"
case b == 0:
return "0"
default:
return strconv.FormatInt(b, 10) + "B"
}
return fmt.Sprintf("%.2f%s", value, multiple)
}

View File

@ -1,106 +0,0 @@
// +build go1.16
package middleware
import (
"io/fs"
"net/http"
"net/http/httptest"
"os"
"testing"
"testing/fstest"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestStatic_CustomFS(t *testing.T) {
var testCases = []struct {
name string
filesystem fs.FS
root string
whenURL string
expectContains string
expectCode int
}{
{
name: "ok, serve index with Echo message",
whenURL: "/",
filesystem: os.DirFS("../_fixture"),
expectCode: http.StatusOK,
expectContains: "<title>Echo</title>",
},
{
name: "ok, serve index with Echo message",
whenURL: "/_fixture/",
filesystem: os.DirFS(".."),
expectCode: http.StatusOK,
expectContains: "<title>Echo</title>",
},
{
name: "ok, serve file from map fs",
whenURL: "/file.txt",
filesystem: fstest.MapFS{
"file.txt": &fstest.MapFile{Data: []byte("file.txt is ok")},
},
expectCode: http.StatusOK,
expectContains: "file.txt is ok",
},
{
name: "nok, missing file in map fs",
whenURL: "/file.txt",
expectCode: http.StatusNotFound,
filesystem: fstest.MapFS{
"file2.txt": &fstest.MapFile{Data: []byte("file2.txt is ok")},
},
},
{
name: "nok, file is not a subpath of root",
whenURL: `/../../secret.txt`,
root: "/nested/folder",
filesystem: fstest.MapFS{
"secret.txt": &fstest.MapFile{Data: []byte("this is a secret")},
},
expectCode: http.StatusNotFound,
},
{
name: "nok, backslash is forbidden",
whenURL: `/..\..\secret.txt`,
expectCode: http.StatusNotFound,
root: "/nested/folder",
filesystem: fstest.MapFS{
"secret.txt": &fstest.MapFile{Data: []byte("this is a secret")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
config := StaticConfig{
Root: ".",
Filesystem: http.FS(tc.filesystem),
}
if tc.root != "" {
config.Root = tc.root
}
middlewareFunc := StaticWithConfig(config)
e.Use(middlewareFunc)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
if tc.expectContains != "" {
responseBody := rec.Body.String()
assert.Contains(t, responseBody, tc.expectContains)
}
})
}
}

View File

@ -1,15 +1,49 @@
package middleware
import (
"io/fs"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"testing/fstest"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestStatic_useCaseForApiAndSPAs(t *testing.T) {
e := echo.New()
// serve single page application (SPA) files from server root
e.Use(StaticWithConfig(StaticConfig{
Root: ".",
// by default Echo filesystem is fixed to `./` but this does not allow `../` (moving up in folder structure past filesystem root)
Filesystem: os.DirFS("../_fixture"),
}))
// all requests to `/api/*` will end up in echo handlers (assuming there is not `api` folder and files)
api := e.Group("/api")
users := api.Group("/users")
users.GET("/info", func(c echo.Context) error {
return c.String(http.StatusOK, "users info")
})
req := httptest.NewRequest(http.MethodGet, "/api/users/info", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "users info", rec.Body.String())
req = httptest.NewRequest(http.MethodGet, "/index.html", nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "<title>Echo</title>")
}
func TestStatic(t *testing.T) {
var testCases = []struct {
name string
@ -35,7 +69,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, when html5 mode serve index for any static file that does not exist",
givenConfig: &StaticConfig{
Root: "../_fixture",
Root: "_fixture",
HTML5: true,
},
whenURL: "/random",
@ -45,7 +79,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, serve index as directory index listing files directory",
givenConfig: &StaticConfig{
Root: "../_fixture/certs",
Root: "_fixture/certs",
Browse: true,
},
whenURL: "/",
@ -55,7 +89,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, serve directory index with IgnoreBase and browse",
givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true,
Browse: true,
},
@ -67,7 +101,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, serve file with IgnoreBase",
givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true,
Browse: true,
},
@ -95,15 +129,27 @@ func TestStatic(t *testing.T) {
expectContains: "{\"message\":\"Not Found\"}\n",
},
{
name: "ok, do not serve file, when a handler took care of the request",
name: "ok, when no file then a handler will care of the request",
whenURL: "/regular-handler",
expectCode: http.StatusOK,
expectContains: "ok",
},
{
name: "ok, skip middleware and serve handler",
givenConfig: &StaticConfig{
Root: "_fixture/images/",
Skipper: func(c echo.Context) bool {
return true
},
},
whenURL: "/walle.png",
expectCode: http.StatusTeapot,
expectContains: "walle",
},
{
name: "nok, when html5 fail if the index file does not exist",
givenConfig: &StaticConfig{
Root: "../_fixture",
Root: "_fixture",
HTML5: true,
Index: "missing.html",
},
@ -114,7 +160,7 @@ func TestStatic(t *testing.T) {
name: "ok, serve from http.FileSystem",
givenConfig: &StaticConfig{
Root: "_fixture",
Filesystem: http.Dir(".."),
Filesystem: os.DirFS(".."),
},
whenURL: "/",
expectCode: http.StatusOK,
@ -125,8 +171,9 @@ func TestStatic(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Filesystem = os.DirFS("../")
config := StaticConfig{Root: "../_fixture"}
config := StaticConfig{Root: "_fixture"}
if tc.givenConfig != nil {
config = *tc.givenConfig
}
@ -136,14 +183,17 @@ func TestStatic(t *testing.T) {
subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc)
// group without http handlers (routes) does not do anything.
// Request is matched against http handlers (routes) that have group middleware attached to them
subGroup.GET("", echo.NotFoundHandler)
subGroup.GET("/*", echo.NotFoundHandler)
subGroup.GET("", func(c echo.Context) error { return echo.ErrNotFound })
subGroup.GET("/*", func(c echo.Context) error { return echo.ErrNotFound })
} else {
// middleware is on root level
e.Use(middlewareFunc)
e.GET("/regular-handler", func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e.GET("/walle.png", func(c echo.Context) error {
return c.String(http.StatusTeapot, "walle")
})
}
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
@ -177,7 +227,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "ok",
givenPrefix: "/images",
givenRoot: "../_fixture/images",
givenRoot: "_fixture/images",
whenURL: "/group/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
@ -185,7 +235,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "No file",
givenPrefix: "/images",
givenRoot: "../_fixture/scripts",
givenRoot: "_fixture/scripts",
whenURL: "/group/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -193,7 +243,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Directory not found (no trailing slash)",
givenPrefix: "/images",
givenRoot: "../_fixture/images",
givenRoot: "_fixture/images",
whenURL: "/group/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -201,7 +251,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Directory redirect",
givenPrefix: "/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/group/folder/",
@ -211,7 +261,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory 404 (request URL without slash)",
givenGroup: "_fixture",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -220,7 +270,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory redirect (without slash redirect to slash)",
givenGroup: "_fixture",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/_fixture/folder/",
@ -229,7 +279,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -237,7 +287,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -245,7 +295,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -253,7 +303,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -261,7 +311,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "../_fixture/",
givenRoot: "_fixture/",
whenURL: `/group/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -269,7 +319,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "../_fixture/",
givenRoot: "_fixture/",
whenURL: `/group/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -279,6 +329,8 @@ func TestStatic_GroupWithStatic(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Filesystem = os.DirFS("../") // so we can access test files
group := "/group"
if tc.givenGroup != "" {
group = tc.givenGroup
@ -288,7 +340,9 @@ func TestStatic_GroupWithStatic(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
@ -306,3 +360,166 @@ func TestStatic_GroupWithStatic(t *testing.T) {
})
}
}
func TestMustStaticWithConfig_panicsInvalidDirListTemplate(t *testing.T) {
assert.Panics(t, func() {
StaticWithConfig(StaticConfig{DirectoryListTemplate: `{{}`})
})
}
func TestFormat(t *testing.T) {
var testCases = []struct {
name string
when int64
expect string
}{
{
name: "byte",
when: 0,
expect: "0",
},
{
name: "bytes",
when: 515,
expect: "515B",
},
{
name: "KB",
when: 31323,
expect: "30.59KB",
},
{
name: "MB",
when: 13231323,
expect: "12.62MB",
},
{
name: "GB",
when: 7323232398,
expect: "6.82GB",
},
{
name: "TB",
when: 1_099_511_627_776,
expect: "1.00TB",
},
{
name: "PB",
when: 9923232398434432,
expect: "8.81PB",
},
{
// test with 7EB because of https://github.com/labstack/gommon/pull/38 and https://github.com/labstack/gommon/pull/43
//
// 8 exbi equals 2^64, therefore it cannot be stored in int64. The tests use
// the fact that on x86_64 the following expressions holds true:
// int64(0) - 1 == math.MaxInt64.
//
// However, this is not true for other platforms, specifically aarch64, s390x
// and ppc64le.
name: "EB",
when: 8070450532247929000,
expect: "7.00EB",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := format(tc.when)
assert.Equal(t, tc.expect, result)
})
}
}
func TestStatic_CustomFS(t *testing.T) {
var testCases = []struct {
name string
filesystem fs.FS
root string
whenURL string
expectContains string
expectCode int
}{
{
name: "ok, serve index with Echo message",
whenURL: "/",
filesystem: os.DirFS("../_fixture"),
expectCode: http.StatusOK,
expectContains: "<title>Echo</title>",
},
{
name: "ok, serve index with Echo message",
whenURL: "/_fixture/",
filesystem: os.DirFS(".."),
expectCode: http.StatusOK,
expectContains: "<title>Echo</title>",
},
{
name: "ok, serve file from map fs",
whenURL: "/file.txt",
filesystem: fstest.MapFS{
"file.txt": &fstest.MapFile{Data: []byte("file.txt is ok")},
},
expectCode: http.StatusOK,
expectContains: "file.txt is ok",
},
{
name: "nok, missing file in map fs",
whenURL: "/file.txt",
expectCode: http.StatusNotFound,
filesystem: fstest.MapFS{
"file2.txt": &fstest.MapFile{Data: []byte("file2.txt is ok")},
},
},
{
name: "nok, file is not a subpath of root",
whenURL: `/../../secret.txt`,
root: "/nested/folder",
filesystem: fstest.MapFS{
"secret.txt": &fstest.MapFile{Data: []byte("this is a secret")},
},
expectCode: http.StatusNotFound,
},
{
name: "nok, backslash is forbidden",
whenURL: `/..\..\secret.txt`,
expectCode: http.StatusNotFound,
root: "/nested/folder",
filesystem: fstest.MapFS{
"secret.txt": &fstest.MapFile{Data: []byte("this is a secret")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
config := StaticConfig{
Root: ".",
Filesystem: tc.filesystem,
}
if tc.root != "" {
config.Root = tc.root
}
middlewareFunc, err := config.ToMiddleware()
assert.NoError(t, err)
e.Use(middlewareFunc)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
if tc.expectContains != "" {
responseBody := rec.Body.String()
assert.Contains(t, responseBody, tc.expectContains)
}
})
}
}

View File

@ -1,220 +0,0 @@
package middleware
import (
"context"
"github.com/labstack/echo/v4"
"net/http"
"sync"
"time"
)
// ---------------------------------------------------------------------------------------------------------------
// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
// WARNING: Timeout middleware causes more problems than it solves.
// WARNING: This middleware should be first middleware as it messes with request Writer and could cause data race if
// it is in other position
//
// Depending on out requirements you could be better of setting timeout to context and
// check its deadline from handler.
//
// For example: create middleware to set timeout to context
// func RequestTimeout(timeout time.Duration) echo.MiddlewareFunc {
// return func(next echo.HandlerFunc) echo.HandlerFunc {
// return func(c echo.Context) error {
// timeoutCtx, cancel := context.WithTimeout(c.Request().Context(), timeout)
// c.SetRequest(c.Request().WithContext(timeoutCtx))
// defer cancel()
// return next(c)
// }
// }
//}
//
// Create handler that checks for context deadline and runs actual task in separate coroutine
// Note: separate coroutine may not be even if you do not want to process continue executing and
// just want to stop long-running handler to stop and you are using "context aware" methods (ala db queries with ctx)
// e.GET("/", func(c echo.Context) error {
//
// doneCh := make(chan error)
// go func(ctx context.Context) {
// doneCh <- myPossiblyLongRunningBackgroundTaskWithCtx(ctx)
// }(c.Request().Context())
//
// select { // wait for task to finish or context to timeout/cancelled
// case err := <-doneCh:
// if err != nil {
// return err
// }
// return c.String(http.StatusOK, "OK")
// case <-c.Request().Context().Done():
// if c.Request().Context().Err() == context.DeadlineExceeded {
// return c.String(http.StatusServiceUnavailable, "timeout")
// }
// return c.Request().Context().Err()
// }
//
// })
//
// TimeoutConfig defines the config for Timeout middleware.
type TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
// It can be used to define a custom timeout error message
ErrorMessage string
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
// request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration
}
var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorMessage: "",
}
)
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
// call runs for longer than its time limit. NB: timeout does not stop handler execution.
func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig)
}
// TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}
// ToMiddleware converts Config to middleware or returns an error for invalid configuration
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) || config.Timeout == 0 {
return next(c)
}
errChan := make(chan error, 1)
handlerWrapper := echoHandlerFuncWrapper{
writer: &ignorableWriter{ResponseWriter: c.Response().Writer},
ctx: c,
handler: next,
errChan: errChan,
errHandler: config.OnTimeoutRouteErrorHandler,
}
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(handlerWrapper.writer, c.Request())
select {
case err := <-errChan:
return err
default:
return nil
}
}
}, nil
}
type echoHandlerFuncWrapper struct {
writer *ignorableWriter
ctx echo.Context
handler echo.HandlerFunc
errHandler func(err error, c echo.Context)
errChan chan error
}
func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// replace echo.Context Request with the one provided by TimeoutHandler to let later middlewares/handler on the chain
// handle properly it's cancellation
t.ctx.SetRequest(r)
// replace writer with TimeoutHandler custom one. This will guarantee that
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
originalWriter := t.ctx.Response().Writer
t.ctx.Response().Writer = rw
// in case of panic we restore original writer and call panic again
// so it could be handled with global middleware Recover()
defer func() {
if err := recover(); err != nil {
t.ctx.Response().Writer = originalWriter
panic(err)
}
}()
err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx)
}
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
}
if err != nil {
// This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write
// status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes
// create errors in log about `superfluous response.WriteHeader call from`
t.writer.Ignore(true)
t.ctx.Response().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error
// we pass error from handler to middlewares up in handler chain to act on it if needed.
t.errChan <- err
return
}
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
// and should not anymore send additional headers/data
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
t.ctx.Response().Writer = originalWriter
}
// ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This
// is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make
// sure that external code will not be able to write response to the client.
// Writer is coroutine safe for writes.
type ignorableWriter struct {
http.ResponseWriter
lock sync.Mutex
ignoreWrites bool
}
func (w *ignorableWriter) Ignore(ignore bool) {
w.lock.Lock()
w.ignoreWrites = ignore
w.lock.Unlock()
}
func (w *ignorableWriter) WriteHeader(code int) {
w.lock.Lock()
defer w.lock.Unlock()
if w.ignoreWrites {
return
}
w.ResponseWriter.WriteHeader(code)
}
func (w *ignorableWriter) Write(b []byte) (int, error) {
w.lock.Lock()
defer w.lock.Unlock()
if w.ignoreWrites {
return len(b), nil
}
return w.ResponseWriter.Write(b)
}

View File

@ -1,484 +0,0 @@
package middleware
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestTimeoutSkipper(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
Skipper: func(context echo.Context) bool {
return true
},
Timeout: 1 * time.Nanosecond,
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
err := m(func(c echo.Context) error {
time.Sleep(25 * time.Microsecond)
return errors.New("response from handler")
})(c)
// if not skipped we would have not returned error due context timeout logic
assert.EqualError(t, err, "response from handler")
}
func TestTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
m := Timeout()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
err := m(func(c echo.Context) error {
assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String())
return nil
})(c)
assert.NoError(t, err)
}
func TestTimeoutErrorOutInHandler(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 50 * time.Millisecond,
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
rec.Code = 1 // we want to be sure that even 200 will not be sent
err := m(func(c echo.Context) error {
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
// to handle returned error and this can be done only then handler has not yet committed (written status code)
// the response.
return echo.NewHTTPError(http.StatusTeapot, "err")
})(c)
assert.Error(t, err)
assert.EqualError(t, err, "code=418, message=err")
assert.Equal(t, 1, rec.Code)
assert.Equal(t, "", rec.Body.String())
}
func TestTimeoutSuccessfulRequest(t *testing.T) {
t.Parallel()
m := TimeoutWithConfig(TimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 50 * time.Millisecond,
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
err := m(func(c echo.Context) error {
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
})(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
}
func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
t.Parallel()
actualErrChan := make(chan error, 1)
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 1 * time.Millisecond,
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) {
actualErrChan <- err
},
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
stopChan := make(chan struct{}, 0)
err := m(func(c echo.Context) error {
<-stopChan
return errors.New("error in route after timeout")
})(c)
stopChan <- struct{}{}
assert.NoError(t, err)
actualErr := <-actualErrChan
assert.EqualError(t, actualErr, "error in route after timeout")
}
func TestTimeoutTestRequestClone(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
m := TimeoutWithConfig(TimeoutConfig{
// Timeout has to be defined or the whole flow for timeout middleware will be skipped
Timeout: 1 * time.Second,
})
e := echo.New()
c := e.NewContext(req, rec)
err := m(func(c echo.Context) error {
// Cookie test
cookie, err := c.Request().Cookie("cookie")
if assert.NoError(t, err) {
assert.EqualValues(t, "cookie", cookie.Name)
assert.EqualValues(t, "value", cookie.Value)
}
// Form values
if assert.NoError(t, c.Request().ParseForm()) {
assert.EqualValues(t, "value", c.Request().FormValue("form"))
}
// Query string
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
return nil
})(c)
assert.NoError(t, err)
}
func TestTimeoutRecoversPanic(t *testing.T) {
t.Parallel()
e := echo.New()
e.Use(Recover()) // recover middleware will handler our panic
e.Use(TimeoutWithConfig(TimeoutConfig{
Timeout: 50 * time.Millisecond,
}))
e.GET("/", func(c echo.Context) error {
panic("panic!!!")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
assert.NotPanics(t, func() {
e.ServeHTTP(rec, req)
})
}
func TestTimeoutDataRace(t *testing.T) {
t.Parallel()
timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: timeout,
ErrorMessage: "Timeout! change me",
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
err := m(func(c echo.Context) error {
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
time.Sleep(timeout) // timeout and handler execution time difference is close to zero
return c.String(http.StatusOK, "Hello, World!")
})(c)
assert.NoError(t, err)
if rec.Code == http.StatusServiceUnavailable {
assert.Equal(t, "Timeout! change me", rec.Body.String())
} else {
assert.Equal(t, "Hello, World!", rec.Body.String())
}
}
func TestTimeoutWithErrorMessage(t *testing.T) {
t.Parallel()
timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: timeout,
ErrorMessage: "Timeout! change me",
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
stopChan := make(chan struct{}, 0)
err := m(func(c echo.Context) error {
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
<-stopChan
return c.String(http.StatusOK, "Hello, World!")
})(c)
stopChan <- struct{}{}
assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
assert.Equal(t, "Timeout! change me", rec.Body.String())
}
func TestTimeoutWithDefaultErrorMessage(t *testing.T) {
t.Parallel()
timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: timeout,
ErrorMessage: "",
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
stopChan := make(chan struct{}, 0)
err := m(func(c echo.Context) error {
<-stopChan
return c.String(http.StatusOK, "Hello, World!")
})(c)
stopChan <- struct{}{}
assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
assert.Equal(t, `<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>`, rec.Body.String())
}
func TestTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
t.Parallel()
timeout := 1 * time.Millisecond
m := TimeoutWithConfig(TimeoutConfig{
Timeout: timeout,
ErrorMessage: "Timeout! change me",
})
handlerFinishedExecution := make(chan bool)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := echo.New()
c := e.NewContext(req, rec)
stopChan := make(chan struct{})
err := m(func(c echo.Context) error {
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
<-stopChan
// The Request Context should have a Deadline set by http.TimeoutHandler
if _, ok := c.Request().Context().Deadline(); !ok {
assert.Fail(t, "No timeout set on Request Context")
}
handlerFinishedExecution <- c.Request().Context().Err() == nil
return c.String(http.StatusOK, "Hello, World!")
})(c)
stopChan <- struct{}{}
assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
assert.Equal(t, "Timeout! change me", rec.Body.String())
assert.False(t, <-handlerFinishedExecution)
}
func TestTimeoutWithFullEchoStack(t *testing.T) {
// test timeout with full http server stack running, do see what http.Server.ErrorLog contains
var testCases = []struct {
name string
whenPath string
whenForceHandlerTimeout bool
expectStatusCode int
expectResponse string
expectLogContains []string
expectLogNotContains []string
}{
{
name: "404 - write response in global error handler",
whenPath: "/404",
expectResponse: "{\"message\":\"Not Found\"}\n",
expectStatusCode: http.StatusNotFound,
expectLogNotContains: []string{"echo:http: superfluous response.WriteHeader call from"},
expectLogContains: []string{`"status":404,"error":"code=404, message=Not Found"`},
},
{
name: "418 - write response in handler",
whenPath: "/",
expectResponse: "{\"message\":\"OK\"}\n",
expectStatusCode: http.StatusTeapot,
expectLogNotContains: []string{"echo:http: superfluous response.WriteHeader call from"},
expectLogContains: []string{`"status":418,"error":"",`},
},
{
name: "503 - handler timeouts, write response in timeout middleware",
whenForceHandlerTimeout: true,
whenPath: "/",
expectResponse: "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>",
expectStatusCode: http.StatusServiceUnavailable,
expectLogNotContains: []string{
"echo:http: superfluous response.WriteHeader call from",
},
expectLogContains: []string{"http: Handler timeout"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
buf := new(coroutineSafeBuffer)
e.Logger.SetOutput(buf)
// NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first
e.Use(TimeoutWithConfig(TimeoutConfig{
Timeout: 15 * time.Millisecond,
}))
e.Use(Logger())
e.Use(Recover())
wg := sync.WaitGroup{}
if tc.whenForceHandlerTimeout {
wg.Add(1) // make `wg.Wait()` block until we release it with `wg.Done()`
}
e.GET("/", func(c echo.Context) error {
wg.Wait()
return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"})
})
server, addr, err := startServer(e)
if err != nil {
assert.NoError(t, err)
return
}
defer server.Close()
res, err := http.Get(fmt.Sprintf("http://%v%v", addr, tc.whenPath))
if err != nil {
assert.NoError(t, err)
return
}
if tc.whenForceHandlerTimeout {
wg.Done()
// shutdown waits for server to shutdown. this way we wait logger mw to be executed
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
server.Shutdown(ctx)
}
assert.Equal(t, tc.expectStatusCode, res.StatusCode)
if body, err := ioutil.ReadAll(res.Body); err == nil {
assert.Equal(t, tc.expectResponse, string(body))
} else {
assert.Fail(t, err.Error())
}
logged := buf.String()
for _, subStr := range tc.expectLogContains {
assert.True(t, strings.Contains(logged, subStr), "expected logs to contain: %v, logged: '%v'", subStr, logged)
}
for _, subStr := range tc.expectLogNotContains {
assert.False(t, strings.Contains(logged, subStr), "expected logs not to contain: %v, logged: '%v'", subStr, logged)
}
})
}
}
// as we are spawning multiple coroutines - one for http server, one for request, one by timeout middleware, one by testcase
// we are accessing logger (writing/reading) from multiple coroutines and causing dataraces (most often reported on macos)
// we could be writing to logger in logger middleware and at the same time our tests is getting logger buffer contents
// in testcase coroutine.
type coroutineSafeBuffer struct {
bytes.Buffer
lock sync.RWMutex
}
func (b *coroutineSafeBuffer) Write(p []byte) (n int, err error) {
b.lock.Lock()
defer b.lock.Unlock()
return b.Buffer.Write(p)
}
func (b *coroutineSafeBuffer) Bytes() []byte {
b.lock.RLock()
defer b.lock.RUnlock()
return b.Buffer.Bytes()
}
func (b *coroutineSafeBuffer) String() string {
b.lock.RLock()
defer b.lock.RUnlock()
return b.Buffer.String()
}
func startServer(e *echo.Echo) (*http.Server, string, error) {
l, err := net.Listen("tcp", ":0")
if err != nil {
return nil, "", err
}
s := http.Server{
Handler: e,
ErrorLog: log.New(e.Logger.Output(), "echo:", 0),
}
errCh := make(chan error)
go func() {
if err := s.Serve(l); err != http.ErrServerClosed {
errCh <- err
}
}()
select {
case <-time.After(10 * time.Millisecond):
return &s, l.Addr().String(), nil
case err := <-errCh:
return nil, "", err
}
}

View File

@ -1,9 +1,27 @@
package middleware
import (
"crypto/rand"
"fmt"
"strings"
)
const (
_ = int64(1 << (10 * iota)) // ignore first value by assigning to blank identifier
// KB is 1 KiloByte = 1024 bytes
KB
// MB is 1 Megabyte = 1_048_576 bytes
MB
// GB is 1 Gigabyte = 1_073_741_824 bytes
GB
// TB is 1 Terabyte = 1_099_511_627_776 bytes
TB
// PB is 1 Petabyte = 1_125_899_906_842_624 bytes
PB
// EB is 1 Exabyte = 1_152_921_504_606_847_000 bytes
EB
)
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
@ -52,3 +70,24 @@ func matchSubdomain(domain, pattern string) bool {
}
return false
}
func createRandomStringGenerator(length uint8) func() string {
return func() string {
return randomString(length)
}
}
func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}

View File

@ -1,11 +1,23 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
"io"
"testing"
)
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}
func Test_matchScheme(t *testing.T) {
tests := []struct {
domain, pattern string
@ -93,3 +105,27 @@ func Test_matchSubdomain(t *testing.T) {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
}
}
func TestRandomString(t *testing.T) {
var testCases = []struct {
name string
whenLength uint8
expect string
}{
{
name: "ok, 16",
whenLength: 16,
},
{
name: "ok, 32",
whenLength: 32,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uid := randomString(tc.whenLength)
assert.Len(t, uid, int(tc.whenLength))
})
}
}

View File

@ -2,15 +2,15 @@ package echo
import (
"bufio"
"errors"
"net"
"net/http"
)
type (
// Response wraps an http.ResponseWriter and implements its interface to be used
// by an HTTP handler to construct an HTTP response.
// See: https://golang.org/pkg/net/http/#ResponseWriter
Response struct {
type Response struct {
echo *Echo
beforeFuncs []func()
afterFuncs []func()
@ -19,7 +19,6 @@ type (
Size int64
Committed bool
}
)
// NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
@ -47,13 +46,15 @@ func (r *Response) After(fn func()) {
r.afterFuncs = append(r.afterFuncs, fn)
}
var errHeaderAlreadyCommitted = errors.New("response already committed")
// WriteHeader sends an HTTP response header with status code. If WriteHeader is
// not called explicitly, the first call to Write will trigger an implicit
// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly
// used to send error codes.
func (r *Response) WriteHeader(code int) {
if r.Committed {
r.echo.Logger.Warn("response already committed")
r.echo.Logger.Error(errHeaderAlreadyCommitted)
return
}
r.Status = code

182
route.go Normal file
View File

@ -0,0 +1,182 @@
package echo
import (
"bytes"
"errors"
"fmt"
"reflect"
"runtime"
)
// Route contains information to adding/registering new route with the router.
// Method+Path pair uniquely identifies the Route. It is mandatory to provide Method+Path+Handler fields.
type Route struct {
Method string
Path string
Handler HandlerFunc
Middlewares []MiddlewareFunc
Name string
}
// ToRouteInfo converts Route to RouteInfo
func (r Route) ToRouteInfo(params []string) RouteInfo {
name := r.Name
if name == "" {
name = r.Method + ":" + r.Path
}
return routeInfo{
method: r.Method,
path: r.Path,
params: append([]string(nil), params...),
name: name,
}
}
// ToRoute returns Route which Router uses to register the method handler for path.
func (r Route) ToRoute() Route {
return r
}
// ForGroup recreates Route with added group prefix and group middlewares it is grouped to.
func (r Route) ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable {
r.Path = pathPrefix + r.Path
if len(middlewares) > 0 {
m := make([]MiddlewareFunc, 0, len(middlewares)+len(r.Middlewares))
m = append(m, middlewares...)
m = append(m, r.Middlewares...)
r.Middlewares = m
}
return r
}
type routeInfo struct {
method string
path string
params []string
name string
}
func (r routeInfo) Method() string {
return r.method
}
func (r routeInfo) Path() string {
return r.path
}
func (r routeInfo) Params() []string {
return append([]string(nil), r.params...)
}
func (r routeInfo) Name() string {
return r.name
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r routeInfo) Reverse(params ...interface{}) string {
uri := new(bytes.Buffer)
ln := len(params)
n := 0
for i, l := 0, len(r.path); i < l; i++ {
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
for ; i < l && r.path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
n++
}
if i < l {
uri.WriteByte(r.path[i])
}
}
return uri.String()
}
// HandlerName returns string name for given function.
func HandlerName(h HandlerFunc) string {
t := reflect.ValueOf(h).Type()
if t.Kind() == reflect.Func {
return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
}
return t.String()
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r Routes) Reverse(name string, params ...interface{}) (string, error) {
for _, rr := range r {
if rr.Name() == name {
return rr.Reverse(params...), nil
}
}
return "", errors.New("route not found")
}
// FindByMethodPath searched for matching route info by method and path
func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) {
if r == nil {
return nil, errors.New("route not found by method and path")
}
for _, rr := range r {
if rr.Method() == method && rr.Path() == path {
return rr, nil
}
}
return nil, errors.New("route not found by method and path")
}
// FilterByMethod searched for matching route info by method
func (r Routes) FilterByMethod(method string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by method")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Method() == method {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by method")
}
return result, nil
}
// FilterByPath searched for matching route info by path
func (r Routes) FilterByPath(path string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by path")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Path() == path {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by path")
}
return result, nil
}
// FilterByName searched for matching route info by name
func (r Routes) FilterByName(name string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by name")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Name() == name {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by name")
}
return result, nil
}

423
route_test.go Normal file
View File

@ -0,0 +1,423 @@
package echo
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
var myNamedHandler = func(c Context) error {
return nil
}
type NameStruct struct {
}
func (n *NameStruct) getUsers(c Context) error {
return nil
}
func TestHandlerName(t *testing.T) {
myNameFuncVar := func(c Context) error {
return nil
}
tmp := NameStruct{}
var testCases = []struct {
name string
whenHandlerFunc HandlerFunc
expect string
}{
{
name: "ok, func as anonymous func",
whenHandlerFunc: func(c Context) error {
return nil
},
expect: "github.com/labstack/echo/v5.TestHandlerName.func2",
},
{
name: "ok, func as named package variable",
whenHandlerFunc: myNamedHandler,
expect: "github.com/labstack/echo/v5.glob..func4",
},
{
name: "ok, func as named function variable",
whenHandlerFunc: myNameFuncVar,
expect: "github.com/labstack/echo/v5.TestHandlerName.func1",
},
{
name: "ok, func as struct method",
whenHandlerFunc: tmp.getUsers,
expect: "github.com/labstack/echo/v5.(*NameStruct).getUsers-fm",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
name := HandlerName(tc.whenHandlerFunc)
assert.Equal(t, tc.expect, name)
})
}
}
func TestHandlerName_differentFuncSameName(t *testing.T) {
handlerCreator := func(name string) HandlerFunc {
return func(c Context) error {
return c.String(http.StatusTeapot, name)
}
}
h1 := handlerCreator("name1")
assert.Equal(t, "github.com/labstack/echo/v5.TestHandlerName_differentFuncSameName.func2", HandlerName(h1))
h2 := handlerCreator("name2")
assert.Equal(t, "github.com/labstack/echo/v5.TestHandlerName_differentFuncSameName.func3", HandlerName(h2))
}
func TestRoute_ToRouteInfo(t *testing.T) {
var testCases = []struct {
name string
given Route
whenParams []string
expect RouteInfo
}{
{
name: "ok, no params, with name",
given: Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
},
expect: routeInfo{
method: http.MethodGet,
path: "/test",
params: nil,
name: "test route",
},
},
{
name: "ok, params",
given: Route{
Method: http.MethodGet,
Path: "users/:id/:file", // no slash prefix
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "",
},
whenParams: []string{"id", "file"},
expect: routeInfo{
method: http.MethodGet,
path: "users/:id/:file",
params: []string{"id", "file"},
name: "GET:users/:id/:file",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri := tc.given.ToRouteInfo(tc.whenParams)
assert.Equal(t, tc.expect, ri)
})
}
}
func TestRoute_ToRoute(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
r := route.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/test")
assert.NotNil(t, r.Handler)
assert.Nil(t, r.Middlewares)
assert.Equal(t, r.Name, "test route")
}
func TestRoute_ForGroup(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
return next(c)
}
}
gr := route.ForGroup("/users", []MiddlewareFunc{mw})
r := gr.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/users/test")
assert.NotNil(t, r.Handler)
assert.Len(t, r.Middlewares, 1)
assert.Equal(t, r.Name, "test route")
}
func exampleRoutes() Routes {
return Routes{
routeInfo{
method: http.MethodGet,
path: "/users",
params: nil,
name: "GET:/users",
},
routeInfo{
method: http.MethodGet,
path: "/users/:id",
params: []string{"id"},
name: "GET:/users/:id",
},
routeInfo{
method: http.MethodPost,
path: "/users/:id",
params: []string{"id"},
name: "POST:/users/:id",
},
routeInfo{
method: http.MethodDelete,
path: "/groups",
params: nil,
name: "non_unique_name",
},
routeInfo{
method: http.MethodPost,
path: "/groups",
params: nil,
name: "non_unique_name",
},
}
}
func TestRoutes_FindByMethodPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
whenPath string
expectName string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "GET:/users/:id",
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri, err := tc.given.FindByMethodPath(tc.whenMethod, tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
assert.Nil(t, ri)
} else {
assert.NoError(t, err)
}
if tc.expectName != "" {
assert.Equal(t, tc.expectName, ri.Name())
}
})
}
}
func TestRoutes_FilterByMethod(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
expectNames: []string{"GET:/users", "GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
expectNames: nil,
expectError: "route not found by method",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
expectNames: nil,
expectError: "route not found by method",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByMethod(tc.whenMethod)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenPath string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenPath: "/users/:id",
expectNames: []string{"GET:/users/:id", "POST:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenPath: "/",
expectNames: nil,
expectError: "route not found by path",
},
{
name: "nok, not found from nil",
given: nil,
whenPath: "/users/:id",
expectNames: nil,
expectError: "route not found by path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByPath(tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByName(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenName string
expectMethodPath []string
expectError string
}{
{
name: "ok, found multiple",
given: exampleRoutes(),
whenName: "non_unique_name",
expectMethodPath: []string{"DELETE:/groups", "POST:/groups"},
},
{
name: "ok, found single",
given: exampleRoutes(),
whenName: "GET:/users/:id",
expectMethodPath: []string{"GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenName: "/",
expectMethodPath: nil,
expectError: "route not found by name",
},
{
name: "nok, not found from nil",
given: nil,
whenName: "/users/:id",
expectMethodPath: nil,
expectError: "route not found by name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByName(tc.whenName)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectMethodPath) > 0 {
assert.Len(t, ris, len(tc.expectMethodPath))
for _, ri := range ris {
assert.Contains(t, tc.expectMethodPath, fmt.Sprintf("%v:%v", ri.Method(), ri.Path()))
}
} else {
assert.Nil(t, ris)
}
})
}
}

821
router.go

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

213
server.go Normal file
View File

@ -0,0 +1,213 @@
package echo
import (
stdContext "context"
"crypto/tls"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"time"
)
const (
banner = "Echo (v%s). High performance, minimalist Go web framework https://echo.labstack.com"
)
// StartConfig is for creating configured http.Server instance to start serve http(s) requests with given Echo instance
type StartConfig struct {
// Address for the server to listen on (if not using custom listener)
Address string
// ListenerNetwork allows setting listener network (see net.Listen for allowed values)
// Optional: defaults to "tcp"
ListenerNetwork string
// CertFilesystem is file system used to load certificates and keys (if certs/keys are given as paths)
CertFilesystem fs.FS
// DisableHTTP2 disables supports for HTTP2 in TLS server
DisableHTTP2 bool
// HideBanner does not log Echo banner on server startup
HideBanner bool
// HidePort does not log port on server startup
HidePort bool
// GracefulContext is context that completion signals graceful shutdown start
GracefulContext stdContext.Context
// GracefulTimeout is period which server allows listeners to finish serving ongoing requests. If this time is exceeded process is exited
// Defaults to 10 seconds
GracefulTimeout time.Duration
// OnShutdownError allows customization of what happens when (graceful) server Shutdown method returns an error.
// Defaults to calling e.logger.Error(err)
OnShutdownError func(err error)
// TLSConfigFunc allows modifying TLS configuration before listener is created with it.
TLSConfigFunc func(tlsConfig *tls.Config)
// ListenerAddrFunc allows getting listener address before server starts serving requests on listener. Useful when
// address is set as random (`:0`) port.
ListenerAddrFunc func(addr net.Addr)
// BeforeServeFunc allows customizing/accessing server before server starts serving requests on listener.
BeforeServeFunc func(s *http.Server) error
}
// Start starts a HTTP server.
func (sc StartConfig) Start(e *Echo) error {
logger := e.Logger
server := http.Server{
Handler: e,
// NB: all http.Server errors will be logged through Logger.Write calls. We could create writer that wraps
// logger and calls Logger.Error internally when http.Server logs error - atm we will use this naive way.
ErrorLog: log.New(logger, "", 0),
}
var tlsConfig *tls.Config = nil
if sc.TLSConfigFunc != nil {
tlsConfig = &tls.Config{}
configureTLS(&sc, tlsConfig)
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &server, listener, logger)
}
// StartTLS starts a HTTPS server.
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
func (sc StartConfig) StartTLS(e *Echo, certFile, keyFile interface{}) error {
logger := e.Logger
s := http.Server{
Handler: e,
// NB: all http.Server errors will be logged through Logger.Write calls. We could create writer that wraps
// logger and calls Logger.Error internally when http.Server logs error - atm we will use this naive way.
ErrorLog: log.New(logger, "", 0),
}
certFs := sc.CertFilesystem
if certFs == nil {
certFs = os.DirFS(".")
}
cert, err := filepathOrContent(certFile, certFs)
if err != nil {
return err
}
key, err := filepathOrContent(keyFile, certFs)
if err != nil {
return err
}
cer, err := tls.X509KeyPair(cert, key)
if err != nil {
return err
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
configureTLS(&sc, tlsConfig)
if sc.TLSConfigFunc != nil {
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &s, listener, logger)
}
func serve(sc *StartConfig, server *http.Server, listener net.Listener, logger Logger) error {
if sc.BeforeServeFunc != nil {
if err := sc.BeforeServeFunc(server); err != nil {
return err
}
}
startupGreetings(sc, logger, listener)
if sc.GracefulContext != nil {
ctx, cancel := stdContext.WithCancel(sc.GracefulContext)
defer cancel() // make sure this graceful coroutine will end when serve returns by some other means
go gracefulShutdown(ctx, sc, server, logger)
}
return server.Serve(listener)
}
func configureTLS(sc *StartConfig, tlsConfig *tls.Config) {
if !sc.DisableHTTP2 {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2")
}
}
func createListener(sc *StartConfig, tlsConfig *tls.Config) (net.Listener, error) {
listenerNetwork := sc.ListenerNetwork
if listenerNetwork == "" {
listenerNetwork = "tcp"
}
var listener net.Listener
var err error
if tlsConfig != nil {
listener, err = tls.Listen(listenerNetwork, sc.Address, tlsConfig)
} else {
listener, err = net.Listen(listenerNetwork, sc.Address)
}
if err != nil {
return nil, err
}
if sc.ListenerAddrFunc != nil {
sc.ListenerAddrFunc(listener.Addr())
}
return listener, nil
}
func startupGreetings(sc *StartConfig, logger Logger, listener net.Listener) {
if !sc.HideBanner {
bannerText := fmt.Sprintf(banner, Version)
logger.Write([]byte(bannerText))
}
if !sc.HidePort {
logger.Write([]byte(fmt.Sprintf("http(s) server started on %s", listener.Addr())))
}
}
func filepathOrContent(fileOrContent interface{}, certFilesystem fs.FS) (content []byte, err error) {
switch v := fileOrContent.(type) {
case string:
return fs.ReadFile(certFilesystem, v)
case []byte:
return v, nil
default:
return nil, ErrInvalidCertOrKeyType
}
}
func gracefulShutdown(gracefulCtx stdContext.Context, sc *StartConfig, server *http.Server, logger Logger) {
<-gracefulCtx.Done() // wait until shutdown context is closed.
// note: is server if closed by other means this method is still run but is good as no-op
timeout := sc.GracefulTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
shutdownCtx, cancel := stdContext.WithTimeout(stdContext.Background(), timeout)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
// we end up here when listeners are not shut down within given timeout
if sc.OnShutdownError != nil {
sc.OnShutdownError(err)
return
}
logger.Error(fmt.Errorf("failed to shut down server within given timeout: %w", err))
}
}

815
server_test.go Normal file
View File

@ -0,0 +1,815 @@
package echo
import (
"bytes"
stdContext "context"
"crypto/tls"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
)
func startOnRandomPort(ctx stdContext.Context, e *Echo) (string, error) {
addrChan := make(chan string)
errCh := make(chan error)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
return waitForServerStart(addrChan, errCh)
}
func waitForServerStart(addrChan <-chan string, errCh <-chan error) (string, error) {
waitCtx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer cancel()
// wait for addr to arrive
for {
select {
case <-waitCtx.Done():
return "", waitCtx.Err()
case addr := <-addrChan:
return addr, nil
case err := <-errCh:
if err == http.ErrServerClosed { // was closed normally before listener callback was called. should not be possible
return "", nil
}
// failed to start and we did not manage to get even listener part.
return "", err
}
}
}
func doGet(url string) (int, string, error) {
resp, err := http.Get(url)
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp.StatusCode, "", err
}
return resp.StatusCode, string(body), nil
}
func TestStartConfig_Start(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
// check if server is actually up
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
if err == nil {
t.Errorf("missing error")
return
}
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
func TestStartConfig_GracefulShutdown(t *testing.T) {
var testCases = []struct {
name string
whenHandlerTakesLonger bool
expectBody string
expectGracefulError string
}{
{
name: "ok, all handlers returns before graceful shutdown deadline",
whenHandlerTakesLonger: false,
expectBody: "OK",
expectGracefulError: "",
},
{
name: "nok, handlers do not returns before graceful shutdown deadline",
whenHandlerTakesLonger: true,
expectBody: "timeout",
expectGracefulError: stdContext.DeadlineExceeded.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
msg := "OK"
if tc.whenHandlerTakesLonger {
time.Sleep(150 * time.Millisecond)
msg = "timeout"
}
return c.String(http.StatusOK, msg)
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 50*time.Millisecond)
defer shutdown()
shutdownErrChan := make(chan error, 1)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 50 * time.Millisecond,
OnShutdownError: func(err error) {
shutdownErrChan <- err
},
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, tc.expectBody, body)
var shutdownErr error
select {
case shutdownErr = <-shutdownErrChan:
default:
}
if tc.expectGracefulError != "" {
assert.EqualError(t, shutdownErr, tc.expectGracefulError)
} else {
assert.NoError(t, shutdownErr)
}
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Error(t, err)
if err != nil {
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
})
}
}
func TestStartConfig_Start_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("not_implemented")
}
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_Start_createListenerError(t *testing.T) {
e := New()
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
func TestStartConfig_StartTLS(t *testing.T) {
var testCases = []struct {
name string
addr string
certFile string
keyFile string
expectError string
}{
{
name: "ok",
addr: ":0",
},
{
name: "nok, invalid certFile",
addr: ":0",
certFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, invalid keyFile",
addr: ":0",
keyFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, failed to create cert out of certFile and keyFile",
addr: ":0",
keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
expectError: "tls: found a certificate rather than a key in the PEM for the private key",
},
{
name: "nok, invalid tls address",
addr: "nope",
expectError: "listen tcp: address nope: missing port in address",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
if tc.certFile != "" {
certFile = tc.certFile
}
keyFile := "_fixture/certs/key.pem"
if tc.keyFile != "" {
keyFile = tc.keyFile
}
s := &StartConfig{
Address: tc.addr,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectError != "" {
if _, ok := err.(*os.PathError); ok {
assert.Error(t, err) // error messages for unix and windows are different. so name only error type here
} else {
assert.EqualError(t, err, tc.expectError)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestStartConfig_StartTLS_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
assert.Len(t, tlsConfig.Certificates, 1)
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_StartTLSAndStart(t *testing.T) {
// We name if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
e := New()
e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
tlsCtx, tlsShutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer tlsShutdown()
addrTLSChan := make(chan string)
errTLSChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: tlsCtx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrTLSChan <- addr.String()
},
}
errTLSChan <- s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
}()
tlsAddr, err := waitForServerStart(addrTLSChan, errTLSChan)
assert.NoError(t, err)
// check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer shutdown()
addrChan := make(chan string)
errChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errChan <- s.Start(e)
}()
addr, err := waitForServerStart(addrChan, errChan)
assert.NoError(t, err)
// now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
res, err = client.Get(fmt.Sprintf("http://%v", addr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// see if HTTPS works after HTTP listener is also added
res, err = client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestFilepathOrContent(t *testing.T) {
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
require.NoError(t, err)
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
require.NoError(t, err)
testCases := []struct {
name string
cert interface{}
key interface{}
expectedErr error
}{
{
name: `ValidCertAndKeyFilePath`,
cert: "_fixture/certs/cert.pem",
key: "_fixture/certs/key.pem",
expectedErr: nil,
},
{
name: `ValidCertAndKeyByteString`,
cert: cert,
key: key,
expectedErr: nil,
},
{
name: `InvalidKeyType`,
cert: cert,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertType`,
cert: 0,
key: key,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertAndKeyTypes`,
cert: 0,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: ":0",
CertFilesystem: os.DirFS("."),
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, tc.cert, tc.key)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func supportsIPv6() bool {
addrs, _ := net.InterfaceAddrs()
for _, addr := range addrs {
// Check if any interface has local IPv6 assigned
if strings.Contains(addr.String(), "::1") {
return true
}
}
return false
}
func TestStartConfig_WithListenerNetwork(t *testing.T) {
testCases := []struct {
name string
network string
address string
}{
{
name: "tcp ipv4 address",
network: "tcp",
address: "127.0.0.1:1323",
},
{
name: "tcp ipv6 address",
network: "tcp",
address: "[::1]:1323",
},
{
name: "tcp4 ipv4 address",
network: "tcp4",
address: "127.0.0.1:1323",
},
{
name: "tcp6 ipv6 address",
network: "tcp6",
address: "[::1]:1323",
},
}
hasIPv6 := supportsIPv6()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if !hasIPv6 && strings.Contains(tc.address, "::") {
t.Skip("Skipping testing IPv6 for " + tc.address + ", not available")
}
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: tc.address,
ListenerNetwork: tc.network,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.Start(e)
}()
_, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%s/ok", tc.address))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
})
}
}
func TestStartConfig_WithHideBanner(t *testing.T) {
var testCases = []struct {
name string
hideBanner bool
}{
{
name: "hide banner on startup",
hideBanner: true,
},
{
name: "show banner on startup",
hideBanner: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HideBanner: tc.hideBanner,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
contains := strings.Contains(buf.String(), "High performance, minimalist Go web framework")
if tc.hideBanner {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithHidePort(t *testing.T) {
var testCases = []struct {
name string
hidePort bool
}{
{
name: "hide port on startup",
hidePort: true,
},
{
name: "show port on startup",
hidePort: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HidePort: tc.hidePort,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
portMsg := fmt.Sprintf("http(s) server started on")
contains := strings.Contains(buf.String(), portMsg)
if tc.hidePort {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithBeforeServeFunc(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
s := &StartConfig{
Address: ":0",
BeforeServeFunc: func(s *http.Server) error {
return errors.New("is called before serve")
},
}
err := s.Start(e)
assert.EqualError(t, err, "is called before serve")
}
func TestWithDisableHTTP2(t *testing.T) {
var testCases = []struct {
name string
disableHTTP2 bool
}{
{
name: "HTTP2 enabled",
disableHTTP2: false,
},
{
name: "HTTP2 disabled",
disableHTTP2: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
keyFile := "_fixture/certs/key.pem"
s := &StartConfig{
Address: ":0",
DisableHTTP2: tc.disableHTTP2,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
url := fmt.Sprintf("https://%v/ok", addr)
// do ordinary http(s) request
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(url)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// do HTTP2 request
client.Transport = &http2.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
resp, err := client.Get(url)
if err != nil {
if tc.disableHTTP2 {
assert.True(t, strings.Contains(err.Error(), `http2: unexpected ALPN protocol ""; want "h2"`))
return
}
log.Fatalf("Failed get: %s", err)
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed reading response body: %s", err)
}
assert.Equal(t, "OK", string(body))
})
}
}
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Printf(format string, args ...interface{}) {
_, _ = l.output.Write([]byte(fmt.Sprintf(format, args...)))
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}