mirror of
https://github.com/labstack/echo.git
synced 2024-12-20 19:52:47 +02:00
WIP: logger examples
WIP: make default logger implemented custom writer for jsonlike logs WIP: improve examples WIP: defaultErrorHandler use errors.As to unwrap errors. Update readme WIP: default logger logs json, restore e.Start method WIP: clean router.Match a bit WIP: func types/fields have echo.Context has first element WIP: remove yaml tags as functions etc can not be serialized anyway WIP: change BindPathParams,BindQueryParams,BindHeaders from methods to functions and reverse arguments to be like DefaultBinder.Bind is WIP: improved comments, logger now extracts status from error WIP: go mod tidy WIP: rebase with 4.5.0 WIP: * removed todos. * removed StartAutoTLS and StartH2CServer methods from `StartConfig` * KeyAuth middleware errorhandler can swallow the error and resume next middleware WIP: add RouterConfig.UseEscapedPathForMatching to use escaped path for matching request against routes WIP: FIXMEs WIP: upgrade golang-jwt/jwt to `v4` WIP: refactor http methods to return RouteInfo WIP: refactor static not creating multiple routes WIP: refactor route and middleware adding functions not to return error directly WIP: Use 401 for problematic/missing headers for key auth and JWT middleware (#1552, #1402). > In summary, a 401 Unauthorized response should be used for missing or bad authentication WIP: replace `HTTPError.SetInternal` with `HTTPError.WithInternal` so we could not mutate global error variables WIP: add RouteInfo and RouteMatchType into Context what we could know from in middleware what route was matched and/or type of that match (200/404/405) WIP: make notFoundHandler and methodNotAllowedHandler private. encourage that all errors be handled in Echo.HTTPErrorHandler WIP: server cleanup ideas WIP: routable.ForGroup WIP: note about logger middleware WIP: bind should not default values on second try. use crypto rand for better randomness WIP: router add route as interface and returns info as interface WIP: improve flaky test (remains still flaky) WIP: add notes about bind default values WIP: every route can have their own path params names WIP: routerCreator and different tests WIP: different things WIP: remove route implementation WIP: support custom method types WIP: extractor tests WIP: v5.0.x proposal over v4.4.0
This commit is contained in:
parent
c6f0c667f1
commit
6ef5f77bf2
3
.github/workflows/echo.yml
vendored
3
.github/workflows/echo.yml
vendored
@ -27,7 +27,8 @@ jobs:
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
|
||||
# Echo tests with last four major releases
|
||||
go: [1.14, 1.15, 1.16, 1.17]
|
||||
# except v5 starts from 1.16 until there is last four major releases after that
|
||||
go: [1.16, 1.17]
|
||||
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
|
21
.travis.yml
21
.travis.yml
@ -1,21 +0,0 @@
|
||||
arch:
|
||||
- amd64
|
||||
- ppc64le
|
||||
|
||||
language: go
|
||||
go:
|
||||
- 1.14.x
|
||||
- 1.15.x
|
||||
- tip
|
||||
env:
|
||||
- GO111MODULE=on
|
||||
install:
|
||||
- go get -v golang.org/x/lint/golint
|
||||
script:
|
||||
- golint -set_exit_status ./...
|
||||
- go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
6
Makefile
6
Makefile
@ -24,11 +24,11 @@ race: ## Run tests with data race detector
|
||||
@go test -race ${PKG_LIST}
|
||||
|
||||
benchmark: ## Run benchmarks
|
||||
@go test -run="-" -bench=".*" ${PKG_LIST}
|
||||
@go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
|
||||
|
||||
help: ## Display this help screen
|
||||
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
goversion ?= "1.15"
|
||||
test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15
|
||||
goversion ?= "1.16"
|
||||
test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
|
||||
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
|
||||
|
10
README.md
10
README.md
@ -12,6 +12,8 @@
|
||||
|
||||
## Supported Go versions
|
||||
|
||||
Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that.
|
||||
|
||||
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
|
||||
Therefore a Go version capable of understanding /vN suffixed imports is required:
|
||||
|
||||
@ -67,8 +69,8 @@ package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/labstack/echo/v5/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -83,7 +85,9 @@ func main() {
|
||||
e.GET("/", hello)
|
||||
|
||||
// Start server
|
||||
e.Logger.Fatal(e.Start(":1323"))
|
||||
if err := e.Start(":1323"); err != http.ErrServerClosed {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handler
|
||||
|
88
bind.go
88
bind.go
@ -11,42 +11,38 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type (
|
||||
// Binder is the interface that wraps the Bind method.
|
||||
Binder interface {
|
||||
Bind(i interface{}, c Context) error
|
||||
}
|
||||
// Binder is the interface that wraps the Bind method.
|
||||
type Binder interface {
|
||||
Bind(c Context, i interface{}) error
|
||||
}
|
||||
|
||||
// DefaultBinder is the default implementation of the Binder interface.
|
||||
DefaultBinder struct{}
|
||||
// DefaultBinder is the default implementation of the Binder interface.
|
||||
type DefaultBinder struct{}
|
||||
|
||||
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
|
||||
// Types that don't implement this, but do implement encoding.TextUnmarshaler
|
||||
// will use that interface instead.
|
||||
BindUnmarshaler interface {
|
||||
// UnmarshalParam decodes and assigns a value from an form or query param.
|
||||
UnmarshalParam(param string) error
|
||||
}
|
||||
)
|
||||
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
|
||||
// Types that don't implement this, but do implement encoding.TextUnmarshaler
|
||||
// will use that interface instead.
|
||||
type BindUnmarshaler interface {
|
||||
// UnmarshalParam decodes and assigns a value from an form or query param.
|
||||
UnmarshalParam(param string) error
|
||||
}
|
||||
|
||||
// BindPathParams binds path params to bindable object
|
||||
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error {
|
||||
names := c.ParamNames()
|
||||
values := c.ParamValues()
|
||||
func BindPathParams(c Context, i interface{}) error {
|
||||
params := map[string][]string{}
|
||||
for i, name := range names {
|
||||
params[name] = []string{values[i]}
|
||||
for _, param := range c.PathParams() {
|
||||
params[param.Name] = []string{param.Value}
|
||||
}
|
||||
if err := b.bindData(i, params, "param"); err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
if err := bindData(i, params, "param"); err != nil {
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindQueryParams binds query params to bindable object
|
||||
func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
|
||||
if err := b.bindData(i, c.QueryParams(), "query"); err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
func BindQueryParams(c Context, i interface{}) error {
|
||||
if err := bindData(i, c.QueryParams(), "query"); err != nil {
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
|
||||
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
|
||||
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
|
||||
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
|
||||
func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
|
||||
func BindBody(c Context, i interface{}) (err error) {
|
||||
req := c.Request()
|
||||
if req.ContentLength == 0 {
|
||||
return
|
||||
@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
|
||||
case *HTTPError:
|
||||
return err
|
||||
default:
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
}
|
||||
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
|
||||
if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
|
||||
if ute, ok := err.(*xml.UnsupportedTypeError); ok {
|
||||
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
|
||||
} else if se, ok := err.(*xml.SyntaxError); ok {
|
||||
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
|
||||
}
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
|
||||
params, err := c.FormParams()
|
||||
if err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
if err = b.bindData(i, params, "form"); err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
if err = bindData(i, params, "form"); err != nil {
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedMediaType
|
||||
@ -98,17 +94,17 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
|
||||
|
||||
// BindHeaders binds HTTP headers to a bindable object
|
||||
func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error {
|
||||
if err := b.bindData(i, c.Request().Header, "header"); err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
|
||||
if err := bindData(i, c.Request().Header, "header"); err != nil {
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bind implements the `Binder#Bind` function.
|
||||
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
|
||||
// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
|
||||
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
|
||||
if err := b.BindPathParams(c, i); err != nil {
|
||||
// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
|
||||
func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) {
|
||||
if err := BindPathParams(c, i); err != nil {
|
||||
return err
|
||||
}
|
||||
// Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH)
|
||||
@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
|
||||
// i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding.
|
||||
// This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body
|
||||
if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete {
|
||||
if err = b.BindQueryParams(c, i); err != nil {
|
||||
if err = BindQueryParams(c, i); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return b.BindBody(c, i)
|
||||
return BindBody(c, i)
|
||||
}
|
||||
|
||||
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
|
||||
func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error {
|
||||
func bindData(destination interface{}, data map[string][]string, tag string) error {
|
||||
if destination == nil || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -170,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
|
||||
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
|
||||
// structs that implement BindUnmarshaler are binded only when they have explicit tag
|
||||
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
|
||||
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
|
||||
if err := bindData(structField.Addr().Interface(), data, tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -297,7 +293,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
|
||||
|
||||
func setIntField(value string, bitSize int, field reflect.Value) error {
|
||||
if value == "" {
|
||||
value = "0"
|
||||
return nil
|
||||
}
|
||||
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||
if err == nil {
|
||||
@ -308,7 +304,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
|
||||
|
||||
func setUintField(value string, bitSize int, field reflect.Value) error {
|
||||
if value == "" {
|
||||
value = "0"
|
||||
return nil
|
||||
}
|
||||
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||
if err == nil {
|
||||
@ -319,7 +315,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
|
||||
|
||||
func setBoolField(value string, field reflect.Value) error {
|
||||
if value == "" {
|
||||
value = "false"
|
||||
return nil
|
||||
}
|
||||
boolVal, err := strconv.ParseBool(value)
|
||||
if err == nil {
|
||||
@ -330,7 +326,7 @@ func setBoolField(value string, field reflect.Value) error {
|
||||
|
||||
func setFloatField(value string, bitSize int, field reflect.Value) error {
|
||||
if value == "" {
|
||||
value = "0.0"
|
||||
return nil
|
||||
}
|
||||
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||
if err == nil {
|
||||
|
242
bind_test.go
242
bind_test.go
@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBind_CombineQueryWithHeaderParam(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil)
|
||||
req.Header.Set("language", "de")
|
||||
req.Header.Set("length", "99")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetPathParams(PathParams{{
|
||||
Name: "id",
|
||||
Value: "999",
|
||||
}})
|
||||
|
||||
type SearchOpts struct {
|
||||
ID int `param:"id"`
|
||||
Length int `query:"length"`
|
||||
Page int `query:"page"`
|
||||
Search string `query:"search"`
|
||||
Language string `query:"language" header:"language"`
|
||||
}
|
||||
|
||||
opts := SearchOpts{
|
||||
Length: 100,
|
||||
Page: 0,
|
||||
Search: "default value",
|
||||
Language: "en",
|
||||
}
|
||||
err := c.Bind(&opts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 50, opts.Length) // bind from query
|
||||
assert.Equal(t, 10, opts.Page) // bind from query
|
||||
assert.Equal(t, 999, opts.ID) // bind from path param
|
||||
assert.Equal(t, "et", opts.Language) // bind from query
|
||||
assert.Equal(t, "default value", opts.Search) // default value stays
|
||||
|
||||
// make sure another bind will not mess already set values unless there are new values
|
||||
err = (&DefaultBinder{}).BindHeaders(c, &opts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists
|
||||
assert.Equal(t, 10, opts.Page)
|
||||
assert.Equal(t, 999, opts.ID)
|
||||
assert.Equal(t, "de", opts.Language) // header overwrites now this value
|
||||
assert.Equal(t, "default value", opts.Search)
|
||||
}
|
||||
|
||||
func TestBindUnmarshalParam(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
|
||||
@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) {
|
||||
|
||||
func TestBindUnmarshalText(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
result := struct {
|
||||
@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
|
||||
|
||||
func TestBindUnmarshalTextPtr(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
result := struct {
|
||||
@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) {
|
||||
func TestBindbindData(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
ts := new(bindTestStruct)
|
||||
b := new(DefaultBinder)
|
||||
err := b.bindData(ts, values, "form")
|
||||
err := bindData(ts, values, "form")
|
||||
a.NoError(err)
|
||||
|
||||
a.Equal(0, ts.I)
|
||||
@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) {
|
||||
|
||||
func TestBindParam(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(GET, "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetPath("/users/:id/:name")
|
||||
c.SetParamNames("id", "name")
|
||||
c.SetParamValues("1", "Jon Snow")
|
||||
cc := c.(EditableContext)
|
||||
cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"})
|
||||
cc.SetPathParams(PathParams{
|
||||
{Name: "id", Value: "1"},
|
||||
{Name: "name", Value: "Jon Snow"},
|
||||
})
|
||||
|
||||
u := new(user)
|
||||
err := c.Bind(u)
|
||||
@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) {
|
||||
|
||||
// Second test for the absence of a param
|
||||
c2 := e.NewContext(req, rec)
|
||||
c2.SetPath("/users/:id")
|
||||
c2.SetParamNames("id")
|
||||
c2.SetParamValues("1")
|
||||
cc2 := c2.(EditableContext)
|
||||
cc2.SetRouteInfo(routeInfo{path: "/users/:id"})
|
||||
cc2.SetPathParams(PathParams{
|
||||
{Name: "id", Value: "1"},
|
||||
})
|
||||
|
||||
u = new(user)
|
||||
err = c2.Bind(u)
|
||||
@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) {
|
||||
// Bind something with param and post data payload
|
||||
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
|
||||
e2 := New()
|
||||
req2 := httptest.NewRequest(POST, "/", body)
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
|
||||
|
||||
rec2 := httptest.NewRecorder()
|
||||
|
||||
c3 := e2.NewContext(req2, rec2)
|
||||
c3.SetPath("/users/:id")
|
||||
c3.SetParamNames("id")
|
||||
c3.SetParamValues("1")
|
||||
cc3 := c3.(EditableContext)
|
||||
cc3.SetRouteInfo(routeInfo{path: "/users/:id"})
|
||||
cc3.SetPathParams(PathParams{
|
||||
{Name: "id", Value: "1"},
|
||||
})
|
||||
|
||||
u = new(user)
|
||||
err = c3.Bind(u)
|
||||
@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) {
|
||||
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
|
||||
}
|
||||
|
||||
func TestBindSetFields(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
func TestSetIntField(t *testing.T) {
|
||||
ts := new(bindTestStruct)
|
||||
ts.I = 100
|
||||
|
||||
val := reflect.ValueOf(ts).Elem()
|
||||
|
||||
// empty value does nothing to field
|
||||
// in that way we can have default values by setting field value before binding
|
||||
err := setIntField("", 0, val.FieldByName("I"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 100, ts.I)
|
||||
|
||||
// second set with value sets the value
|
||||
err = setIntField("5", 0, val.FieldByName("I"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, ts.I)
|
||||
|
||||
// third set without value does nothing to the value
|
||||
// in that way multiple binds (ala query + header) do not reset fields to 0s
|
||||
err = setIntField("", 0, val.FieldByName("I"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, ts.I)
|
||||
}
|
||||
|
||||
func TestSetUintField(t *testing.T) {
|
||||
ts := new(bindTestStruct)
|
||||
ts.UI = 100
|
||||
|
||||
val := reflect.ValueOf(ts).Elem()
|
||||
|
||||
// empty value does nothing to field
|
||||
// in that way we can have default values by setting field value before binding
|
||||
err := setUintField("", 0, val.FieldByName("UI"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint(100), ts.UI)
|
||||
|
||||
// second set with value sets the value
|
||||
err = setUintField("5", 0, val.FieldByName("UI"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint(5), ts.UI)
|
||||
|
||||
// third set without value does nothing to the value
|
||||
// in that way multiple binds (ala query + header) do not reset fields to 0s
|
||||
err = setUintField("", 0, val.FieldByName("UI"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint(5), ts.UI)
|
||||
}
|
||||
|
||||
func TestSetFloatField(t *testing.T) {
|
||||
ts := new(bindTestStruct)
|
||||
ts.F32 = 100
|
||||
|
||||
val := reflect.ValueOf(ts).Elem()
|
||||
|
||||
// empty value does nothing to field
|
||||
// in that way we can have default values by setting field value before binding
|
||||
err := setFloatField("", 0, val.FieldByName("F32"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, float32(100), ts.F32)
|
||||
|
||||
// second set with value sets the value
|
||||
err = setFloatField("15.5", 0, val.FieldByName("F32"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, float32(15.5), ts.F32)
|
||||
|
||||
// third set without value does nothing to the value
|
||||
// in that way multiple binds (ala query + header) do not reset fields to 0s
|
||||
err = setFloatField("", 0, val.FieldByName("F32"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, float32(15.5), ts.F32)
|
||||
}
|
||||
|
||||
func TestSetBoolField(t *testing.T) {
|
||||
ts := new(bindTestStruct)
|
||||
ts.B = true
|
||||
|
||||
val := reflect.ValueOf(ts).Elem()
|
||||
|
||||
// empty value does nothing to field
|
||||
// in that way we can have default values by setting field value before binding
|
||||
err := setBoolField("", val.FieldByName("B"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ts.B)
|
||||
|
||||
// second set with value sets the value
|
||||
err = setBoolField("true", val.FieldByName("B"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ts.B)
|
||||
|
||||
// third set without value does nothing to the value
|
||||
// in that way multiple binds (ala query + header) do not reset fields to 0s
|
||||
err = setBoolField("", val.FieldByName("B"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ts.B)
|
||||
|
||||
// fourth set to false
|
||||
err = setBoolField("false", val.FieldByName("B"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ts.B)
|
||||
}
|
||||
|
||||
func TestUnmarshalFieldNonPtr(t *testing.T) {
|
||||
ts := new(bindTestStruct)
|
||||
val := reflect.ValueOf(ts).Elem()
|
||||
// Int
|
||||
if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
|
||||
assert.Equal(5, ts.I)
|
||||
}
|
||||
if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
|
||||
assert.Equal(0, ts.I)
|
||||
}
|
||||
|
||||
// Uint
|
||||
if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
|
||||
assert.Equal(uint(10), ts.UI)
|
||||
}
|
||||
if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
|
||||
assert.Equal(uint(0), ts.UI)
|
||||
}
|
||||
|
||||
// Float
|
||||
if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
|
||||
assert.Equal(float32(15.5), ts.F32)
|
||||
}
|
||||
if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
|
||||
assert.Equal(float32(0.0), ts.F32)
|
||||
}
|
||||
|
||||
// Bool
|
||||
if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
|
||||
assert.Equal(true, ts.B)
|
||||
}
|
||||
if assert.NoError(setBoolField("", val.FieldByName("B"))) {
|
||||
assert.Equal(false, ts.B)
|
||||
}
|
||||
|
||||
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(ok, true)
|
||||
assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
|
||||
if assert.NoError(t, err) {
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
|
||||
}
|
||||
}
|
||||
|
||||
@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
assert := assert.New(b)
|
||||
ts := new(bindTestStructWithTags)
|
||||
binder := new(DefaultBinder)
|
||||
var err error
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err = binder.bindData(ts, values, "form")
|
||||
err = bindData(ts, values, "form")
|
||||
}
|
||||
assert.NoError(err)
|
||||
assertBindTestStruct(assert, (*bindTestStruct)(ts))
|
||||
@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
if !tc.whenNoPathParams {
|
||||
c.SetParamNames("node")
|
||||
c.SetParamValues("node_from_path")
|
||||
cc := c.(EditableContext)
|
||||
cc.SetPathParams(PathParams{
|
||||
{Name: "node", Value: "node_from_path"},
|
||||
})
|
||||
}
|
||||
|
||||
var bindTarget interface{}
|
||||
@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
|
||||
}
|
||||
b := new(DefaultBinder)
|
||||
|
||||
err := b.Bind(bindTarget, c)
|
||||
err := b.Bind(c, bindTarget)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) {
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
if !tc.whenNoPathParams {
|
||||
c.SetParamNames("node")
|
||||
c.SetParamValues("real_node")
|
||||
cc := c.(EditableContext)
|
||||
cc.SetPathParams(PathParams{
|
||||
{Name: "node", Value: "real_node"},
|
||||
})
|
||||
}
|
||||
|
||||
var bindTarget interface{}
|
||||
@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) {
|
||||
} else {
|
||||
bindTarget = &Node{}
|
||||
}
|
||||
b := new(DefaultBinder)
|
||||
|
||||
err := b.BindBody(c, bindTarget)
|
||||
err := BindBody(c, bindTarget)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
|
@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder {
|
||||
func PathParamsBinder(c Context) *ValueBinder {
|
||||
return &ValueBinder{
|
||||
failFast: true,
|
||||
ValueFunc: c.Param,
|
||||
ValueFunc: c.PathParam,
|
||||
ValuesFunc: func(sourceParam string) []string {
|
||||
// path parameter should not have multiple values so getting values does not make sense but lets not error out here
|
||||
value := c.Param(sourceParam)
|
||||
value := c.PathParam(sourceParam)
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
|
@ -30,14 +30,15 @@ func createTestContext15(URL string, body io.Reader, pathParams map[string]strin
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
if len(pathParams) > 0 {
|
||||
names := make([]string, 0)
|
||||
values := make([]string, 0)
|
||||
params := make(PathParams, 0)
|
||||
for name, value := range pathParams {
|
||||
names = append(names, name)
|
||||
values = append(values, value)
|
||||
params = append(params, PathParam{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
c.SetParamNames(names...)
|
||||
c.SetParamValues(values...)
|
||||
cc := c.(EditableContext)
|
||||
cc.SetPathParams(params)
|
||||
}
|
||||
|
||||
return c
|
||||
|
@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string)
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
if len(pathParams) > 0 {
|
||||
names := make([]string, 0)
|
||||
values := make([]string, 0)
|
||||
params := make(PathParams, 0)
|
||||
for name, value := range pathParams {
|
||||
names = append(names, name)
|
||||
values = append(values, value)
|
||||
params = append(params, PathParam{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
c.SetParamNames(names...)
|
||||
c.SetParamValues(values...)
|
||||
cc := c.(EditableContext)
|
||||
cc.SetPathParams(params)
|
||||
}
|
||||
|
||||
return c
|
||||
@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
|
||||
binder := new(DefaultBinder)
|
||||
for i := 0; i < b.N; i++ {
|
||||
var dest Opts
|
||||
_ = binder.Bind(&dest, c)
|
||||
_ = binder.Bind(c, &dest)
|
||||
}
|
||||
}
|
||||
|
||||
@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
|
||||
binder := new(DefaultBinder)
|
||||
for i := 0; i < b.N; i++ {
|
||||
var dest Opts
|
||||
_ = binder.Bind(&dest, c)
|
||||
_ = binder.Bind(c, &dest)
|
||||
if dest.Int64 != 1 {
|
||||
b.Fatalf("int64!=1")
|
||||
}
|
||||
|
433
context.go
433
context.go
@ -3,212 +3,233 @@ package echo
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type (
|
||||
// Context represents the context of the current HTTP request. It holds request and
|
||||
// response objects, path, path parameters, data and registered handler.
|
||||
Context interface {
|
||||
// Request returns `*http.Request`.
|
||||
Request() *http.Request
|
||||
// Context represents the context of the current HTTP request. It holds request and
|
||||
// response objects, path, path parameters, data and registered handler.
|
||||
type Context interface {
|
||||
// Request returns `*http.Request`.
|
||||
Request() *http.Request
|
||||
|
||||
// SetRequest sets `*http.Request`.
|
||||
SetRequest(r *http.Request)
|
||||
// SetRequest sets `*http.Request`.
|
||||
SetRequest(r *http.Request)
|
||||
|
||||
// SetResponse sets `*Response`.
|
||||
SetResponse(r *Response)
|
||||
// SetResponse sets `*Response`.
|
||||
SetResponse(r *Response)
|
||||
|
||||
// Response returns `*Response`.
|
||||
Response() *Response
|
||||
// Response returns `*Response`.
|
||||
Response() *Response
|
||||
|
||||
// IsTLS returns true if HTTP connection is TLS otherwise false.
|
||||
IsTLS() bool
|
||||
// IsTLS returns true if HTTP connection is TLS otherwise false.
|
||||
IsTLS() bool
|
||||
|
||||
// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
|
||||
IsWebSocket() bool
|
||||
// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
|
||||
IsWebSocket() bool
|
||||
|
||||
// Scheme returns the HTTP protocol scheme, `http` or `https`.
|
||||
Scheme() string
|
||||
// Scheme returns the HTTP protocol scheme, `http` or `https`.
|
||||
Scheme() string
|
||||
|
||||
// RealIP returns the client's network address based on `X-Forwarded-For`
|
||||
// or `X-Real-IP` request header.
|
||||
// The behavior can be configured using `Echo#IPExtractor`.
|
||||
RealIP() string
|
||||
// RealIP returns the client's network address based on `X-Forwarded-For`
|
||||
// or `X-Real-IP` request header.
|
||||
// The behavior can be configured using `Echo#IPExtractor`.
|
||||
RealIP() string
|
||||
|
||||
// Path returns the registered path for the handler.
|
||||
Path() string
|
||||
// RouteMatchType returns router match type for current context. This helps middlewares to distinguish which type
|
||||
// of match router found and how this request context handler chain could end:
|
||||
// * route match - this path + method had matching route.
|
||||
// * not found - this path did not match any routes enough to be considered match
|
||||
// * method not allowed - path had routes registered but for other method types then current request is
|
||||
// * unknown - initial state for fresh context before router tries to do routing
|
||||
//
|
||||
// Note: for pre-middleware (Echo.Pre) this method result is always RouteMatchUnknown as at point router has not tried
|
||||
// to match request to route.
|
||||
RouteMatchType() RouteMatchType
|
||||
|
||||
// SetPath sets the registered path for the handler.
|
||||
SetPath(p string)
|
||||
// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
|
||||
// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
|
||||
RouteInfo() RouteInfo
|
||||
|
||||
// Param returns path parameter by name.
|
||||
Param(name string) string
|
||||
// Path returns the registered path for the handler.
|
||||
Path() string
|
||||
|
||||
// ParamNames returns path parameter names.
|
||||
ParamNames() []string
|
||||
// PathParam returns path parameter by name.
|
||||
PathParam(name string) string
|
||||
|
||||
// SetParamNames sets path parameter names.
|
||||
SetParamNames(names ...string)
|
||||
// PathParams returns path parameter values.
|
||||
PathParams() PathParams
|
||||
|
||||
// ParamValues returns path parameter values.
|
||||
ParamValues() []string
|
||||
// SetPathParams set path parameter for during current request lifecycle.
|
||||
SetPathParams(params PathParams)
|
||||
|
||||
// SetParamValues sets path parameter values.
|
||||
SetParamValues(values ...string)
|
||||
// QueryParam returns the query param for the provided name.
|
||||
QueryParam(name string) string
|
||||
|
||||
// QueryParam returns the query param for the provided name.
|
||||
QueryParam(name string) string
|
||||
// QueryParams returns the query parameters as `url.Values`.
|
||||
QueryParams() url.Values
|
||||
|
||||
// QueryParams returns the query parameters as `url.Values`.
|
||||
QueryParams() url.Values
|
||||
// QueryString returns the URL query string.
|
||||
QueryString() string
|
||||
|
||||
// QueryString returns the URL query string.
|
||||
QueryString() string
|
||||
// FormValue returns the form field value for the provided name.
|
||||
FormValue(name string) string
|
||||
|
||||
// FormValue returns the form field value for the provided name.
|
||||
FormValue(name string) string
|
||||
// FormParams returns the form parameters as `url.Values`.
|
||||
FormParams() (url.Values, error)
|
||||
|
||||
// FormParams returns the form parameters as `url.Values`.
|
||||
FormParams() (url.Values, error)
|
||||
// FormFile returns the multipart form file for the provided name.
|
||||
FormFile(name string) (*multipart.FileHeader, error)
|
||||
|
||||
// FormFile returns the multipart form file for the provided name.
|
||||
FormFile(name string) (*multipart.FileHeader, error)
|
||||
// MultipartForm returns the multipart form.
|
||||
MultipartForm() (*multipart.Form, error)
|
||||
|
||||
// MultipartForm returns the multipart form.
|
||||
MultipartForm() (*multipart.Form, error)
|
||||
// Cookie returns the named cookie provided in the request.
|
||||
Cookie(name string) (*http.Cookie, error)
|
||||
|
||||
// Cookie returns the named cookie provided in the request.
|
||||
Cookie(name string) (*http.Cookie, error)
|
||||
// SetCookie adds a `Set-Cookie` header in HTTP response.
|
||||
SetCookie(cookie *http.Cookie)
|
||||
|
||||
// SetCookie adds a `Set-Cookie` header in HTTP response.
|
||||
SetCookie(cookie *http.Cookie)
|
||||
// Cookies returns the HTTP cookies sent with the request.
|
||||
Cookies() []*http.Cookie
|
||||
|
||||
// Cookies returns the HTTP cookies sent with the request.
|
||||
Cookies() []*http.Cookie
|
||||
// Get retrieves data from the context.
|
||||
Get(key string) interface{}
|
||||
|
||||
// Get retrieves data from the context.
|
||||
Get(key string) interface{}
|
||||
// Set saves data in the context.
|
||||
Set(key string, val interface{})
|
||||
|
||||
// Set saves data in the context.
|
||||
Set(key string, val interface{})
|
||||
// Bind binds the request body into provided type `i`. The default binder
|
||||
// does it based on Content-Type header.
|
||||
Bind(i interface{}) error
|
||||
|
||||
// Bind binds the request body into provided type `i`. The default binder
|
||||
// does it based on Content-Type header.
|
||||
Bind(i interface{}) error
|
||||
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
|
||||
// Validator must be registered using `Echo#Validator`.
|
||||
Validate(i interface{}) error
|
||||
|
||||
// Validate validates provided `i`. It is usually called after `Context#Bind()`.
|
||||
// Validator must be registered using `Echo#Validator`.
|
||||
Validate(i interface{}) error
|
||||
// Render renders a template with data and sends a text/html response with status
|
||||
// code. Renderer must be registered using `Echo.Renderer`.
|
||||
Render(code int, name string, data interface{}) error
|
||||
|
||||
// Render renders a template with data and sends a text/html response with status
|
||||
// code. Renderer must be registered using `Echo.Renderer`.
|
||||
Render(code int, name string, data interface{}) error
|
||||
// HTML sends an HTTP response with status code.
|
||||
HTML(code int, html string) error
|
||||
|
||||
// HTML sends an HTTP response with status code.
|
||||
HTML(code int, html string) error
|
||||
// HTMLBlob sends an HTTP blob response with status code.
|
||||
HTMLBlob(code int, b []byte) error
|
||||
|
||||
// HTMLBlob sends an HTTP blob response with status code.
|
||||
HTMLBlob(code int, b []byte) error
|
||||
// String sends a string response with status code.
|
||||
String(code int, s string) error
|
||||
|
||||
// String sends a string response with status code.
|
||||
String(code int, s string) error
|
||||
// JSON sends a JSON response with status code.
|
||||
JSON(code int, i interface{}) error
|
||||
|
||||
// JSON sends a JSON response with status code.
|
||||
JSON(code int, i interface{}) error
|
||||
// JSONPretty sends a pretty-print JSON with status code.
|
||||
JSONPretty(code int, i interface{}, indent string) error
|
||||
|
||||
// JSONPretty sends a pretty-print JSON with status code.
|
||||
JSONPretty(code int, i interface{}, indent string) error
|
||||
// JSONBlob sends a JSON blob response with status code.
|
||||
JSONBlob(code int, b []byte) error
|
||||
|
||||
// JSONBlob sends a JSON blob response with status code.
|
||||
JSONBlob(code int, b []byte) error
|
||||
// JSONP sends a JSONP response with status code. It uses `callback` to construct
|
||||
// the JSONP payload.
|
||||
JSONP(code int, callback string, i interface{}) error
|
||||
|
||||
// JSONP sends a JSONP response with status code. It uses `callback` to construct
|
||||
// the JSONP payload.
|
||||
JSONP(code int, callback string, i interface{}) error
|
||||
// JSONPBlob sends a JSONP blob response with status code. It uses `callback`
|
||||
// to construct the JSONP payload.
|
||||
JSONPBlob(code int, callback string, b []byte) error
|
||||
|
||||
// JSONPBlob sends a JSONP blob response with status code. It uses `callback`
|
||||
// to construct the JSONP payload.
|
||||
JSONPBlob(code int, callback string, b []byte) error
|
||||
// XML sends an XML response with status code.
|
||||
XML(code int, i interface{}) error
|
||||
|
||||
// XML sends an XML response with status code.
|
||||
XML(code int, i interface{}) error
|
||||
// XMLPretty sends a pretty-print XML with status code.
|
||||
XMLPretty(code int, i interface{}, indent string) error
|
||||
|
||||
// XMLPretty sends a pretty-print XML with status code.
|
||||
XMLPretty(code int, i interface{}, indent string) error
|
||||
// XMLBlob sends an XML blob response with status code.
|
||||
XMLBlob(code int, b []byte) error
|
||||
|
||||
// XMLBlob sends an XML blob response with status code.
|
||||
XMLBlob(code int, b []byte) error
|
||||
// Blob sends a blob response with status code and content type.
|
||||
Blob(code int, contentType string, b []byte) error
|
||||
|
||||
// Blob sends a blob response with status code and content type.
|
||||
Blob(code int, contentType string, b []byte) error
|
||||
// Stream sends a streaming response with status code and content type.
|
||||
Stream(code int, contentType string, r io.Reader) error
|
||||
|
||||
// Stream sends a streaming response with status code and content type.
|
||||
Stream(code int, contentType string, r io.Reader) error
|
||||
// File sends a response with the content of the file.
|
||||
File(file string) error
|
||||
|
||||
// File sends a response with the content of the file.
|
||||
File(file string) error
|
||||
// Attachment sends a response as attachment, prompting client to save the
|
||||
// file.
|
||||
Attachment(file string, name string) error
|
||||
|
||||
// Attachment sends a response as attachment, prompting client to save the
|
||||
// file.
|
||||
Attachment(file string, name string) error
|
||||
// Inline sends a response as inline, opening the file in the browser.
|
||||
Inline(file string, name string) error
|
||||
|
||||
// Inline sends a response as inline, opening the file in the browser.
|
||||
Inline(file string, name string) error
|
||||
// NoContent sends a response with no body and a status code.
|
||||
NoContent(code int) error
|
||||
|
||||
// NoContent sends a response with no body and a status code.
|
||||
NoContent(code int) error
|
||||
// Redirect redirects the request to a provided URL with status code.
|
||||
Redirect(code int, url string) error
|
||||
|
||||
// Redirect redirects the request to a provided URL with status code.
|
||||
Redirect(code int, url string) error
|
||||
// Error invokes the registered HTTP error handler.
|
||||
// NB: Avoid using this method. It is better to return errors so middlewares up in chain could act on returned error.
|
||||
Error(err error)
|
||||
|
||||
// Error invokes the registered HTTP error handler. Generally used by middleware.
|
||||
Error(err error)
|
||||
// Echo returns the `Echo` instance.
|
||||
Echo() *Echo
|
||||
}
|
||||
|
||||
// Handler returns the matched handler by router.
|
||||
Handler() HandlerFunc
|
||||
// EditableContext is additional interface that structure implementing Context must implement. Methods inside this
|
||||
// interface are meant for Echo internal usage (for mainly routing) and should not be used in middlewares.
|
||||
type EditableContext interface {
|
||||
Context
|
||||
|
||||
// SetHandler sets the matched handler by router.
|
||||
SetHandler(h HandlerFunc)
|
||||
// RawPathParams returns raw path pathParams value.
|
||||
RawPathParams() *PathParams
|
||||
|
||||
// Logger returns the `Logger` instance.
|
||||
Logger() Logger
|
||||
// SetRawPathParams replaces any existing param values with new values for this context lifetime (request).
|
||||
SetRawPathParams(params *PathParams)
|
||||
|
||||
// Set the logger
|
||||
SetLogger(l Logger)
|
||||
// SetPath sets the registered path for the handler.
|
||||
SetPath(p string)
|
||||
|
||||
// Echo returns the `Echo` instance.
|
||||
Echo() *Echo
|
||||
// SetRouteMatchType sets the RouteMatchType of router match for this request.
|
||||
SetRouteMatchType(t RouteMatchType)
|
||||
|
||||
// Reset resets the context after request completes. It must be called along
|
||||
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
|
||||
// See `Echo#ServeHTTP()`
|
||||
Reset(r *http.Request, w http.ResponseWriter)
|
||||
}
|
||||
// SetRouteInfo sets the route info of this request to the context.
|
||||
SetRouteInfo(ri RouteInfo)
|
||||
|
||||
context struct {
|
||||
request *http.Request
|
||||
response *Response
|
||||
path string
|
||||
pnames []string
|
||||
pvalues []string
|
||||
query url.Values
|
||||
handler HandlerFunc
|
||||
store Map
|
||||
echo *Echo
|
||||
logger Logger
|
||||
lock sync.RWMutex
|
||||
}
|
||||
)
|
||||
// Reset resets the context after request completes. It must be called along
|
||||
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
|
||||
// See `Echo#ServeHTTP()`
|
||||
Reset(r *http.Request, w http.ResponseWriter)
|
||||
}
|
||||
|
||||
type context struct {
|
||||
request *http.Request
|
||||
response *Response
|
||||
|
||||
matchType RouteMatchType
|
||||
route RouteInfo
|
||||
path string
|
||||
|
||||
// pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations.
|
||||
pathParams *PathParams
|
||||
// currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request.
|
||||
// Lifecycle is not handle by Echo and could have excess allocations per served Request
|
||||
currentParams PathParams
|
||||
|
||||
query url.Values
|
||||
store Map
|
||||
echo *Echo
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
const (
|
||||
defaultMemory = 32 << 20 // 32 MB
|
||||
@ -296,52 +317,50 @@ func (c *context) SetPath(p string) {
|
||||
c.path = p
|
||||
}
|
||||
|
||||
func (c *context) Param(name string) string {
|
||||
for i, n := range c.pnames {
|
||||
if i < len(c.pvalues) {
|
||||
if n == name {
|
||||
return c.pvalues[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
func (c *context) RouteMatchType() RouteMatchType {
|
||||
return c.matchType
|
||||
}
|
||||
|
||||
func (c *context) ParamNames() []string {
|
||||
return c.pnames
|
||||
func (c *context) SetRouteMatchType(t RouteMatchType) {
|
||||
c.matchType = t
|
||||
}
|
||||
|
||||
func (c *context) SetParamNames(names ...string) {
|
||||
c.pnames = names
|
||||
|
||||
l := len(names)
|
||||
if *c.echo.maxParam < l {
|
||||
*c.echo.maxParam = l
|
||||
}
|
||||
|
||||
if len(c.pvalues) < l {
|
||||
// Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
|
||||
// probably those values will be overriden in a Context#SetParamValues
|
||||
newPvalues := make([]string, l)
|
||||
copy(newPvalues, c.pvalues)
|
||||
c.pvalues = newPvalues
|
||||
}
|
||||
func (c *context) RouteInfo() RouteInfo {
|
||||
return c.route
|
||||
}
|
||||
|
||||
func (c *context) ParamValues() []string {
|
||||
return c.pvalues[:len(c.pnames)]
|
||||
func (c *context) SetRouteInfo(ri RouteInfo) {
|
||||
c.route = ri
|
||||
}
|
||||
|
||||
func (c *context) SetParamValues(values ...string) {
|
||||
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times
|
||||
// It will brake the Router#Find code
|
||||
limit := len(values)
|
||||
if limit > *c.echo.maxParam {
|
||||
limit = *c.echo.maxParam
|
||||
func (c *context) RawPathParams() *PathParams {
|
||||
return c.pathParams
|
||||
}
|
||||
|
||||
func (c *context) SetRawPathParams(params *PathParams) {
|
||||
c.pathParams = params
|
||||
}
|
||||
|
||||
func (c *context) PathParam(name string) string {
|
||||
if c.currentParams != nil {
|
||||
return c.currentParams.Get(name, "")
|
||||
}
|
||||
for i := 0; i < limit; i++ {
|
||||
c.pvalues[i] = values[i]
|
||||
|
||||
return c.pathParams.Get(name, "")
|
||||
}
|
||||
|
||||
func (c *context) PathParams() PathParams {
|
||||
if c.currentParams != nil {
|
||||
return c.currentParams
|
||||
}
|
||||
|
||||
result := make(PathParams, len(*c.pathParams))
|
||||
copy(result, *c.pathParams)
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *context) SetPathParams(params PathParams) {
|
||||
c.currentParams = params
|
||||
}
|
||||
|
||||
func (c *context) QueryParam(name string) string {
|
||||
@ -422,7 +441,7 @@ func (c *context) Set(key string, val interface{}) {
|
||||
}
|
||||
|
||||
func (c *context) Bind(i interface{}) error {
|
||||
return c.echo.Binder.Bind(i, c)
|
||||
return c.echo.Binder.Bind(c, i)
|
||||
}
|
||||
|
||||
func (c *context) Validate(i interface{}) error {
|
||||
@ -562,27 +581,36 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *context) File(file string) (err error) {
|
||||
f, err := os.Open(file)
|
||||
func (c *context) File(file string) error {
|
||||
return c.FsFile(file, c.echo.Filesystem)
|
||||
}
|
||||
|
||||
func (c *context) FsFile(file string, filesystem fs.FS) error {
|
||||
// FIXME: should we add this method into echo.Context interface?
|
||||
f, err := filesystem.Open(file)
|
||||
if err != nil {
|
||||
return NotFoundHandler(c)
|
||||
return ErrNotFound
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, _ := f.Stat()
|
||||
if fi.IsDir() {
|
||||
file = filepath.Join(file, indexPage)
|
||||
f, err = os.Open(file)
|
||||
f, err = filesystem.Open(file)
|
||||
if err != nil {
|
||||
return NotFoundHandler(c)
|
||||
return ErrNotFound
|
||||
}
|
||||
defer f.Close()
|
||||
if fi, err = f.Stat(); err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
|
||||
return
|
||||
ff, ok := f.(io.ReadSeeker)
|
||||
if !ok {
|
||||
return errors.New("file does not implement io.ReadSeeker")
|
||||
}
|
||||
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *context) Attachment(file, name string) error {
|
||||
@ -613,44 +641,23 @@ func (c *context) Redirect(code int, url string) error {
|
||||
}
|
||||
|
||||
func (c *context) Error(err error) {
|
||||
c.echo.HTTPErrorHandler(err, c)
|
||||
c.echo.HTTPErrorHandler(c, err)
|
||||
}
|
||||
|
||||
func (c *context) Echo() *Echo {
|
||||
return c.echo
|
||||
}
|
||||
|
||||
func (c *context) Handler() HandlerFunc {
|
||||
return c.handler
|
||||
}
|
||||
|
||||
func (c *context) SetHandler(h HandlerFunc) {
|
||||
c.handler = h
|
||||
}
|
||||
|
||||
func (c *context) Logger() Logger {
|
||||
res := c.logger
|
||||
if res != nil {
|
||||
return res
|
||||
}
|
||||
return c.echo.Logger
|
||||
}
|
||||
|
||||
func (c *context) SetLogger(l Logger) {
|
||||
c.logger = l
|
||||
}
|
||||
|
||||
func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
|
||||
c.request = r
|
||||
c.response.reset(w)
|
||||
c.query = nil
|
||||
c.handler = NotFoundHandler
|
||||
c.store = nil
|
||||
|
||||
c.matchType = RouteMatchUnknown
|
||||
c.route = nil
|
||||
c.path = ""
|
||||
c.pnames = nil
|
||||
c.logger = nil
|
||||
// NOTE: Don't reset because it has to have length c.echo.maxParam at all times
|
||||
for i := 0; i < *c.echo.maxParam; i++ {
|
||||
c.pvalues[i] = ""
|
||||
}
|
||||
// NOTE: Don't reset because it has to have length c.echo.contextPathParamAllocSize at all times
|
||||
*c.pathParams = (*c.pathParams)[:0]
|
||||
c.currentParams = nil
|
||||
}
|
||||
|
488
context_test.go
488
context_test.go
@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@ -18,21 +19,19 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/gommon/log"
|
||||
testify "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type (
|
||||
Template struct {
|
||||
templates *template.Template
|
||||
}
|
||||
)
|
||||
type Template struct {
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
var testUser = user{1, "Jon Snow"}
|
||||
|
||||
func BenchmarkAllocJSONP(b *testing.B) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
|
||||
e.Logger = &jsonLogger{writer: ioutil.Discard}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
@ -46,7 +45,8 @@ func BenchmarkAllocJSONP(b *testing.B) {
|
||||
|
||||
func BenchmarkAllocJSON(b *testing.B) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
|
||||
e.Logger = &jsonLogger{writer: ioutil.Discard}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
@ -60,7 +60,8 @@ func BenchmarkAllocJSON(b *testing.B) {
|
||||
|
||||
func BenchmarkAllocXML(b *testing.B) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
|
||||
e.Logger = &jsonLogger{writer: ioutil.Discard}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
@ -106,16 +107,14 @@ func TestContext(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
assert := testify.New(t)
|
||||
|
||||
// Echo
|
||||
assert.Equal(e, c.Echo())
|
||||
assert.Equal(t, e, c.Echo())
|
||||
|
||||
// Request
|
||||
assert.NotNil(c.Request())
|
||||
assert.NotNil(t, c.Request())
|
||||
|
||||
// Response
|
||||
assert.NotNil(c.Response())
|
||||
assert.NotNil(t, c.Response())
|
||||
|
||||
//--------
|
||||
// Render
|
||||
@ -126,23 +125,23 @@ func TestContext(t *testing.T) {
|
||||
}
|
||||
c.echo.Renderer = tmpl
|
||||
err := c.Render(http.StatusOK, "hello", "Jon Snow")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal("Hello, Jon Snow!", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
|
||||
}
|
||||
|
||||
c.echo.Renderer = nil
|
||||
err = c.Render(http.StatusOK, "hello", "Jon Snow")
|
||||
assert.Error(err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// JSON
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(userJSON+"\n", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, userJSON+"\n", rec.Body.String())
|
||||
}
|
||||
|
||||
// JSON with "?pretty"
|
||||
@ -150,10 +149,10 @@ func TestContext(t *testing.T) {
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(userJSONPretty+"\n", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
|
||||
}
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
|
||||
|
||||
@ -161,37 +160,37 @@ func TestContext(t *testing.T) {
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(userJSONPretty+"\n", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
|
||||
}
|
||||
|
||||
// JSON (error)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.JSON(http.StatusOK, make(chan bool))
|
||||
assert.Error(err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// JSONP
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
callback := "callback"
|
||||
err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
|
||||
}
|
||||
|
||||
// XML
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.XML(http.StatusOK, user{1, "Jon Snow"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(xml.Header+userXML, rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, xml.Header+userXML, rec.Body.String())
|
||||
}
|
||||
|
||||
// XML with "?pretty"
|
||||
@ -199,10 +198,10 @@ func TestContext(t *testing.T) {
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.XML(http.StatusOK, user{1, "Jon Snow"})
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
|
||||
}
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
@ -210,22 +209,22 @@ func TestContext(t *testing.T) {
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.XML(http.StatusOK, make(chan bool))
|
||||
assert.Error(err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// XML response write error
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
c.response.Writer = responseWriterErr{}
|
||||
err = c.XML(0, 0)
|
||||
testify.Error(t, err)
|
||||
assert.Error(t, err)
|
||||
|
||||
// XMLPretty
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
|
||||
}
|
||||
|
||||
t.Run("empty indent", func(t *testing.T) {
|
||||
@ -237,7 +236,6 @@ func TestContext(t *testing.T) {
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
assert := testify.New(t)
|
||||
|
||||
// New JSONBlob with empty indent
|
||||
rec = httptest.NewRecorder()
|
||||
@ -246,16 +244,15 @@ func TestContext(t *testing.T) {
|
||||
enc.SetIndent(emptyIndent, emptyIndent)
|
||||
err = enc.Encode(u)
|
||||
err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(buf.String(), rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, buf.String(), rec.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("xml", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
assert := testify.New(t)
|
||||
|
||||
// New XMLBlob with empty indent
|
||||
rec = httptest.NewRecorder()
|
||||
@ -264,10 +261,10 @@ func TestContext(t *testing.T) {
|
||||
enc.Indent(emptyIndent, emptyIndent)
|
||||
err = enc.Encode(u)
|
||||
err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(xml.Header+buf.String(), rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
|
||||
}
|
||||
})
|
||||
})
|
||||
@ -276,12 +273,12 @@ func TestContext(t *testing.T) {
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
data, err := json.Marshal(user{1, "Jon Snow"})
|
||||
assert.NoError(err)
|
||||
assert.NoError(t, err)
|
||||
err = c.JSONBlob(http.StatusOK, data)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(userJSON, rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, userJSON, rec.Body.String())
|
||||
}
|
||||
|
||||
// Legacy JSONPBlob
|
||||
@ -289,44 +286,44 @@ func TestContext(t *testing.T) {
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
callback = "callback"
|
||||
data, err = json.Marshal(user{1, "Jon Snow"})
|
||||
assert.NoError(err)
|
||||
assert.NoError(t, err)
|
||||
err = c.JSONPBlob(http.StatusOK, callback, data)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(callback+"("+userJSON+");", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
|
||||
}
|
||||
|
||||
// Legacy XMLBlob
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
data, err = xml.Marshal(user{1, "Jon Snow"})
|
||||
assert.NoError(err)
|
||||
assert.NoError(t, err)
|
||||
err = c.XMLBlob(http.StatusOK, data)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(xml.Header+userXML, rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, xml.Header+userXML, rec.Body.String())
|
||||
}
|
||||
|
||||
// String
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.String(http.StatusOK, "Hello, World!")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal("Hello, World!", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, "Hello, World!", rec.Body.String())
|
||||
}
|
||||
|
||||
// HTML
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal("Hello, <strong>World!</strong>", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String())
|
||||
}
|
||||
|
||||
// Stream
|
||||
@ -334,55 +331,55 @@ func TestContext(t *testing.T) {
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
r := strings.NewReader("response from a stream")
|
||||
err = c.Stream(http.StatusOK, "application/octet-stream", r)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType))
|
||||
assert.Equal("response from a stream", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, "response from a stream", rec.Body.String())
|
||||
}
|
||||
|
||||
// Attachment
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.Attachment("_fixture/images/walle.png", "walle.png")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
|
||||
assert.Equal(219885, rec.Body.Len())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
|
||||
assert.Equal(t, 219885, rec.Body.Len())
|
||||
}
|
||||
|
||||
// Inline
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
err = c.Inline("_fixture/images/walle.png", "walle.png")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
|
||||
assert.Equal(219885, rec.Body.Len())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
|
||||
assert.Equal(t, 219885, rec.Body.Len())
|
||||
}
|
||||
|
||||
// NoContent
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
c.NoContent(http.StatusOK)
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Error
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec).(*context)
|
||||
c.Error(errors.New("error"))
|
||||
assert.Equal(http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
|
||||
// Reset
|
||||
c.SetParamNames("foo")
|
||||
c.SetParamValues("bar")
|
||||
c.pathParams = &PathParams{
|
||||
{Name: "foo", Value: "bar"},
|
||||
}
|
||||
c.Set("foe", "ban")
|
||||
c.query = url.Values(map[string][]string{"fon": {"baz"}})
|
||||
c.Reset(req, httptest.NewRecorder())
|
||||
assert.Equal(0, len(c.ParamValues()))
|
||||
assert.Equal(0, len(c.ParamNames()))
|
||||
assert.Equal(0, len(c.store))
|
||||
assert.Equal("", c.Path())
|
||||
assert.Equal(0, len(c.QueryParams()))
|
||||
assert.Equal(t, 0, len(c.PathParams()))
|
||||
assert.Equal(t, 0, len(c.store))
|
||||
assert.Equal(t, nil, c.RouteInfo())
|
||||
assert.Equal(t, 0, len(c.QueryParams()))
|
||||
}
|
||||
|
||||
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
|
||||
@ -392,11 +389,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
|
||||
|
||||
assert := testify.New(t)
|
||||
if assert.NoError(err) {
|
||||
assert.Equal(http.StatusCreated, rec.Code)
|
||||
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(userJSON+"\n", rec.Body.String())
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
|
||||
assert.Equal(t, userJSON+"\n", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@ -407,9 +403,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
|
||||
|
||||
assert := testify.New(t)
|
||||
if assert.Error(err) {
|
||||
assert.False(c.response.Committed)
|
||||
if assert.Error(t, err) {
|
||||
assert.False(t, c.response.Committed)
|
||||
}
|
||||
}
|
||||
|
||||
@ -423,22 +418,20 @@ func TestContextCookie(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
assert := testify.New(t)
|
||||
|
||||
// Read single
|
||||
cookie, err := c.Cookie("theme")
|
||||
if assert.NoError(err) {
|
||||
assert.Equal("theme", cookie.Name)
|
||||
assert.Equal("light", cookie.Value)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "theme", cookie.Name)
|
||||
assert.Equal(t, "light", cookie.Value)
|
||||
}
|
||||
|
||||
// Read multiple
|
||||
for _, cookie := range c.Cookies() {
|
||||
switch cookie.Name {
|
||||
case "theme":
|
||||
assert.Equal("light", cookie.Value)
|
||||
assert.Equal(t, "light", cookie.Value)
|
||||
case "user":
|
||||
assert.Equal("Jon Snow", cookie.Value)
|
||||
assert.Equal(t, "Jon Snow", cookie.Value)
|
||||
}
|
||||
}
|
||||
|
||||
@ -453,104 +446,95 @@ func TestContextCookie(t *testing.T) {
|
||||
HttpOnly: true,
|
||||
}
|
||||
c.SetCookie(cookie)
|
||||
assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID")
|
||||
assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
|
||||
assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com")
|
||||
assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure")
|
||||
assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly")
|
||||
}
|
||||
|
||||
func TestContextPath(t *testing.T) {
|
||||
e := New()
|
||||
r := e.Router()
|
||||
|
||||
handler := func(c Context) error { return c.String(http.StatusOK, "OK") }
|
||||
|
||||
r.Add(http.MethodGet, "/users/:id", handler)
|
||||
c := e.NewContext(nil, nil)
|
||||
r.Find(http.MethodGet, "/users/1", c)
|
||||
|
||||
assert := testify.New(t)
|
||||
|
||||
assert.Equal("/users/:id", c.Path())
|
||||
|
||||
r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
|
||||
c = e.NewContext(nil, nil)
|
||||
r.Find(http.MethodGet, "/users/1/files/1", c)
|
||||
assert.Equal("/users/:uid/files/:fid", c.Path())
|
||||
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
|
||||
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
|
||||
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
|
||||
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
|
||||
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
|
||||
}
|
||||
|
||||
func TestContextPathParam(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c := e.NewContext(req, nil)
|
||||
c := e.NewContext(req, nil).(*context)
|
||||
|
||||
params := &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "fid", Value: "501"},
|
||||
}
|
||||
// ParamNames
|
||||
c.SetParamNames("uid", "fid")
|
||||
testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
|
||||
|
||||
// ParamValues
|
||||
c.SetParamValues("101", "501")
|
||||
testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
|
||||
c.pathParams = params
|
||||
assert.EqualValues(t, *params, c.PathParams())
|
||||
|
||||
// Param
|
||||
testify.Equal(t, "501", c.Param("fid"))
|
||||
testify.Equal(t, "", c.Param("undefined"))
|
||||
assert.Equal(t, "501", c.PathParam("fid"))
|
||||
assert.Equal(t, "", c.PathParam("undefined"))
|
||||
}
|
||||
|
||||
func TestContextGetAndSetParam(t *testing.T) {
|
||||
e := New()
|
||||
r := e.Router()
|
||||
r.Add(http.MethodGet, "/:foo", func(Context) error { return nil })
|
||||
_, err := r.Add(Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/:foo",
|
||||
Name: "",
|
||||
Handler: func(Context) error { return nil },
|
||||
Middlewares: nil,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
|
||||
c := e.NewContext(req, nil)
|
||||
c.SetParamNames("foo")
|
||||
|
||||
params := &PathParams{{Name: "foo", Value: "101"}}
|
||||
// ParamNames
|
||||
c.(*context).pathParams = params
|
||||
|
||||
// round-trip param values with modification
|
||||
paramVals := c.ParamValues()
|
||||
testify.EqualValues(t, []string{""}, c.ParamValues())
|
||||
paramVals[0] = "bar"
|
||||
c.SetParamValues(paramVals...)
|
||||
testify.EqualValues(t, []string{"bar"}, c.ParamValues())
|
||||
paramVals := c.PathParams()
|
||||
assert.Equal(t, *params, c.PathParams())
|
||||
|
||||
paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context
|
||||
assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams())
|
||||
|
||||
pathParams := PathParams{
|
||||
{Name: "aaa", Value: "bbb"},
|
||||
{Name: "ccc", Value: "ddd"},
|
||||
}
|
||||
c.SetPathParams(pathParams)
|
||||
assert.Equal(t, pathParams, c.PathParams())
|
||||
|
||||
// shouldn't explode during Reset() afterwards!
|
||||
testify.NotPanics(t, func() {
|
||||
c.Reset(nil, nil)
|
||||
assert.NotPanics(t, func() {
|
||||
c.(EditableContext).Reset(nil, nil)
|
||||
})
|
||||
assert.Equal(t, PathParams{}, c.PathParams())
|
||||
assert.Len(t, *c.(*context).pathParams, 0)
|
||||
assert.Equal(t, cap(*c.(*context).pathParams), 1)
|
||||
}
|
||||
|
||||
// Issue #1655
|
||||
func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) {
|
||||
assert := testify.New(t)
|
||||
|
||||
func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) {
|
||||
e := New()
|
||||
assert.Equal(0, *e.maxParam)
|
||||
c := e.NewContext(nil, nil).(*context)
|
||||
|
||||
expectedOneParam := []string{"one"}
|
||||
expectedTwoParams := []string{"one", "two"}
|
||||
expectedThreeParams := []string{"one", "two", ""}
|
||||
expectedABCParams := []string{"A", "B", "C"}
|
||||
assert.Equal(t, 0, e.contextPathParamAllocSize)
|
||||
expectedTwoParams := &PathParams{
|
||||
{Name: "1", Value: "one"},
|
||||
{Name: "2", Value: "two"},
|
||||
}
|
||||
c.SetRawPathParams(expectedTwoParams)
|
||||
assert.Equal(t, 0, e.contextPathParamAllocSize)
|
||||
assert.Equal(t, *expectedTwoParams, c.PathParams())
|
||||
|
||||
c := e.NewContext(nil, nil)
|
||||
c.SetParamNames("1", "2")
|
||||
c.SetParamValues(expectedTwoParams...)
|
||||
assert.Equal(2, *e.maxParam)
|
||||
assert.EqualValues(expectedTwoParams, c.ParamValues())
|
||||
|
||||
c.SetParamNames("1")
|
||||
assert.Equal(2, *e.maxParam)
|
||||
// Here for backward compatibility the ParamValues remains as they are
|
||||
assert.EqualValues(expectedOneParam, c.ParamValues())
|
||||
|
||||
c.SetParamNames("1", "2", "3")
|
||||
assert.Equal(3, *e.maxParam)
|
||||
// Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam
|
||||
assert.EqualValues(expectedThreeParams, c.ParamValues())
|
||||
|
||||
c.SetParamValues("A", "B", "C", "D")
|
||||
assert.Equal(3, *e.maxParam)
|
||||
// Here D shouldn't be returned
|
||||
assert.EqualValues(expectedABCParams, c.ParamValues())
|
||||
expectedThreeParams := PathParams{
|
||||
{Name: "1", Value: "one"},
|
||||
{Name: "2", Value: "two"},
|
||||
{Name: "3", Value: "three"},
|
||||
}
|
||||
c.SetPathParams(expectedThreeParams)
|
||||
assert.Equal(t, 0, e.contextPathParamAllocSize)
|
||||
assert.Equal(t, expectedThreeParams, c.PathParams())
|
||||
}
|
||||
|
||||
func TestContextFormValue(t *testing.T) {
|
||||
@ -564,13 +548,13 @@ func TestContextFormValue(t *testing.T) {
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
// FormValue
|
||||
testify.Equal(t, "Jon Snow", c.FormValue("name"))
|
||||
testify.Equal(t, "jon@labstack.com", c.FormValue("email"))
|
||||
assert.Equal(t, "Jon Snow", c.FormValue("name"))
|
||||
assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
|
||||
|
||||
// FormParams
|
||||
params, err := c.FormParams()
|
||||
if testify.NoError(t, err) {
|
||||
testify.Equal(t, url.Values{
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, url.Values{
|
||||
"name": []string{"Jon Snow"},
|
||||
"email": []string{"jon@labstack.com"},
|
||||
}, params)
|
||||
@ -581,8 +565,8 @@ func TestContextFormValue(t *testing.T) {
|
||||
req.Header.Add(HeaderContentType, MIMEMultipartForm)
|
||||
c = e.NewContext(req, nil)
|
||||
params, err = c.FormParams()
|
||||
testify.Nil(t, params)
|
||||
testify.Error(t, err)
|
||||
assert.Nil(t, params)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContextQueryParam(t *testing.T) {
|
||||
@ -594,11 +578,11 @@ func TestContextQueryParam(t *testing.T) {
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
// QueryParam
|
||||
testify.Equal(t, "Jon Snow", c.QueryParam("name"))
|
||||
testify.Equal(t, "jon@labstack.com", c.QueryParam("email"))
|
||||
assert.Equal(t, "Jon Snow", c.QueryParam("name"))
|
||||
assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
|
||||
|
||||
// QueryParams
|
||||
testify.Equal(t, url.Values{
|
||||
assert.Equal(t, url.Values{
|
||||
"name": []string{"Jon Snow"},
|
||||
"email": []string{"jon@labstack.com"},
|
||||
}, c.QueryParams())
|
||||
@ -609,7 +593,7 @@ func TestContextFormFile(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
mr := multipart.NewWriter(buf)
|
||||
w, err := mr.CreateFormFile("file", "test")
|
||||
if testify.NoError(t, err) {
|
||||
if assert.NoError(t, err) {
|
||||
w.Write([]byte("test"))
|
||||
}
|
||||
mr.Close()
|
||||
@ -618,8 +602,8 @@ func TestContextFormFile(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
f, err := c.FormFile("file")
|
||||
if testify.NoError(t, err) {
|
||||
testify.Equal(t, "test", f.Filename)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "test", f.Filename)
|
||||
}
|
||||
}
|
||||
|
||||
@ -634,8 +618,8 @@ func TestContextMultipartForm(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
f, err := c.MultipartForm()
|
||||
if testify.NoError(t, err) {
|
||||
testify.NotNil(t, f)
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, f)
|
||||
}
|
||||
}
|
||||
|
||||
@ -644,16 +628,16 @@ func TestContextRedirect(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
|
||||
testify.Equal(t, http.StatusMovedPermanently, rec.Code)
|
||||
testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
|
||||
testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
|
||||
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
|
||||
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
|
||||
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
|
||||
assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
|
||||
}
|
||||
|
||||
func TestContextStore(t *testing.T) {
|
||||
var c Context = new(context)
|
||||
c.Set("name", "Jon Snow")
|
||||
testify.Equal(t, "Jon Snow", c.Get("name"))
|
||||
assert.Equal(t, "Jon Snow", c.Get("name"))
|
||||
}
|
||||
|
||||
func BenchmarkContext_Store(b *testing.B) {
|
||||
@ -671,42 +655,6 @@ func BenchmarkContext_Store(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextHandler(t *testing.T) {
|
||||
e := New()
|
||||
r := e.Router()
|
||||
b := new(bytes.Buffer)
|
||||
|
||||
r.Add(http.MethodGet, "/handler", func(Context) error {
|
||||
_, err := b.Write([]byte("handler"))
|
||||
return err
|
||||
})
|
||||
c := e.NewContext(nil, nil)
|
||||
r.Find(http.MethodGet, "/handler", c)
|
||||
err := c.Handler()(c)
|
||||
testify.Equal(t, "handler", b.String())
|
||||
testify.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestContext_SetHandler(t *testing.T) {
|
||||
var c Context = new(context)
|
||||
|
||||
testify.Nil(t, c.Handler())
|
||||
|
||||
c.SetHandler(func(c Context) error {
|
||||
return nil
|
||||
})
|
||||
testify.NotNil(t, c.Handler())
|
||||
}
|
||||
|
||||
func TestContext_Path(t *testing.T) {
|
||||
path := "/pa/th"
|
||||
|
||||
var c Context = new(context)
|
||||
|
||||
c.SetPath(path)
|
||||
testify.Equal(t, path, c.Path())
|
||||
}
|
||||
|
||||
type validator struct{}
|
||||
|
||||
func (*validator) Validate(i interface{}) error {
|
||||
@ -717,10 +665,10 @@ func TestContext_Validate(t *testing.T) {
|
||||
e := New()
|
||||
c := e.NewContext(nil, nil)
|
||||
|
||||
testify.Error(t, c.Validate(struct{}{}))
|
||||
assert.Error(t, c.Validate(struct{}{}))
|
||||
|
||||
e.Validator = &validator{}
|
||||
testify.NoError(t, c.Validate(struct{}{}))
|
||||
assert.NoError(t, c.Validate(struct{}{}))
|
||||
}
|
||||
|
||||
func TestContext_QueryString(t *testing.T) {
|
||||
@ -728,21 +676,21 @@ func TestContext_QueryString(t *testing.T) {
|
||||
|
||||
queryString := "query=string&var=val"
|
||||
|
||||
req := httptest.NewRequest(GET, "/?"+queryString, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
testify.Equal(t, queryString, c.QueryString())
|
||||
assert.Equal(t, queryString, c.QueryString())
|
||||
}
|
||||
|
||||
func TestContext_Request(t *testing.T) {
|
||||
var c Context = new(context)
|
||||
|
||||
testify.Nil(t, c.Request())
|
||||
assert.Nil(t, c.Request())
|
||||
|
||||
req := httptest.NewRequest(GET, "/path", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/path", nil)
|
||||
c.SetRequest(req)
|
||||
|
||||
testify.Equal(t, req, c.Request())
|
||||
assert.Equal(t, req, c.Request())
|
||||
}
|
||||
|
||||
func TestContext_Scheme(t *testing.T) {
|
||||
@ -799,14 +747,14 @@ func TestContext_Scheme(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
testify.Equal(t, tt.s, tt.c.Scheme())
|
||||
assert.Equal(t, tt.s, tt.c.Scheme())
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_IsWebSocket(t *testing.T) {
|
||||
tests := []struct {
|
||||
c Context
|
||||
ws testify.BoolAssertionFunc
|
||||
ws assert.BoolAssertionFunc
|
||||
}{
|
||||
{
|
||||
&context{
|
||||
@ -814,7 +762,7 @@ func TestContext_IsWebSocket(t *testing.T) {
|
||||
Header: http.Header{HeaderUpgrade: []string{"websocket"}},
|
||||
},
|
||||
},
|
||||
testify.True,
|
||||
assert.True,
|
||||
},
|
||||
{
|
||||
&context{
|
||||
@ -822,13 +770,13 @@ func TestContext_IsWebSocket(t *testing.T) {
|
||||
Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
|
||||
},
|
||||
},
|
||||
testify.True,
|
||||
assert.True,
|
||||
},
|
||||
{
|
||||
&context{
|
||||
request: &http.Request{},
|
||||
},
|
||||
testify.False,
|
||||
assert.False,
|
||||
},
|
||||
{
|
||||
&context{
|
||||
@ -836,7 +784,7 @@ func TestContext_IsWebSocket(t *testing.T) {
|
||||
Header: http.Header{HeaderUpgrade: []string{"other"}},
|
||||
},
|
||||
},
|
||||
testify.False,
|
||||
assert.False,
|
||||
},
|
||||
}
|
||||
|
||||
@ -849,30 +797,14 @@ func TestContext_IsWebSocket(t *testing.T) {
|
||||
|
||||
func TestContext_Bind(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
|
||||
c := e.NewContext(req, nil)
|
||||
u := new(user)
|
||||
|
||||
req.Header.Add(HeaderContentType, MIMEApplicationJSON)
|
||||
err := c.Bind(u)
|
||||
testify.NoError(t, err)
|
||||
testify.Equal(t, &user{1, "Jon Snow"}, u)
|
||||
}
|
||||
|
||||
func TestContext_Logger(t *testing.T) {
|
||||
e := New()
|
||||
c := e.NewContext(nil, nil)
|
||||
|
||||
log1 := c.Logger()
|
||||
testify.NotNil(t, log1)
|
||||
|
||||
log2 := log.New("echo2")
|
||||
c.SetLogger(log2)
|
||||
testify.Equal(t, log2, c.Logger())
|
||||
|
||||
// Resetting the context returns the initial logger
|
||||
c.Reset(nil, nil)
|
||||
testify.Equal(t, log1, c.Logger())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &user{1, "Jon Snow"}, u)
|
||||
}
|
||||
|
||||
func TestContext_RealIP(t *testing.T) {
|
||||
@ -925,6 +857,6 @@ func TestContext_RealIP(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
testify.Equal(t, tt.s, tt.c.RealIP())
|
||||
assert.Equal(t, tt.s, tt.c.RealIP())
|
||||
}
|
||||
}
|
||||
|
1103
echo_test.go
1103
echo_test.go
File diff suppressed because it is too large
Load Diff
16
go.mod
16
go.mod
@ -1,17 +1,13 @@
|
||||
module github.com/labstack/echo/v4
|
||||
|
||||
go 1.15
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/labstack/gommon v0.3.0
|
||||
github.com/mattn/go-colorable v0.1.8 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/stretchr/testify v1.4.0
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.0.0
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/valyala/fasttemplate v1.2.1
|
||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
|
||||
golang.org/x/net v0.0.0-20210913180222-943fd674d43e
|
||||
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
|
||||
golang.org/x/text v0.3.7 // indirect
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
|
||||
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
)
|
||||
|
48
go.sum
48
go.sum
@ -1,51 +1,29 @@
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
|
||||
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
|
||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
|
||||
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
|
||||
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
|
||||
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ=
|
||||
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8=
|
||||
golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg=
|
||||
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
|
||||
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
174
group.go
174
group.go
@ -4,95 +4,117 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
// Group is a set of sub-routes for a specified route. It can be used for inner
|
||||
// routes that share a common middleware or functionality that should be separate
|
||||
// from the parent echo instance while still inheriting from it.
|
||||
Group struct {
|
||||
common
|
||||
host string
|
||||
prefix string
|
||||
middleware []MiddlewareFunc
|
||||
echo *Echo
|
||||
}
|
||||
)
|
||||
|
||||
// Use implements `Echo#Use()` for sub-routes within the Group.
|
||||
func (g *Group) Use(middleware ...MiddlewareFunc) {
|
||||
g.middleware = append(g.middleware, middleware...)
|
||||
if len(g.middleware) == 0 {
|
||||
return
|
||||
}
|
||||
// Allow all requests to reach the group as they might get dropped if router
|
||||
// doesn't find a match, making none of the group middleware process.
|
||||
g.Any("", NotFoundHandler)
|
||||
g.Any("/*", NotFoundHandler)
|
||||
// Group is a set of sub-routes for a specified route. It can be used for inner
|
||||
// routes that share a common middleware or functionality that should be separate
|
||||
// from the parent echo instance while still inheriting from it.
|
||||
type Group struct {
|
||||
host string
|
||||
prefix string
|
||||
middleware []MiddlewareFunc
|
||||
echo *Echo
|
||||
}
|
||||
|
||||
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
|
||||
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// Use implements `Echo#Use()` for sub-routes within the Group.
|
||||
// Group middlewares are not executed on request when there is no matching route found.
|
||||
func (g *Group) Use(middleware ...MiddlewareFunc) {
|
||||
g.middleware = append(g.middleware, middleware...)
|
||||
}
|
||||
|
||||
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodConnect, path, h, m...)
|
||||
}
|
||||
|
||||
// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
|
||||
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodDelete, path, h, m...)
|
||||
}
|
||||
|
||||
// GET implements `Echo#GET()` for sub-routes within the Group.
|
||||
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodGet, path, h, m...)
|
||||
}
|
||||
|
||||
// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
|
||||
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodHead, path, h, m...)
|
||||
}
|
||||
|
||||
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
|
||||
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodOptions, path, h, m...)
|
||||
}
|
||||
|
||||
// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
|
||||
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodPatch, path, h, m...)
|
||||
}
|
||||
|
||||
// POST implements `Echo#POST()` for sub-routes within the Group.
|
||||
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodPost, path, h, m...)
|
||||
}
|
||||
|
||||
// PUT implements `Echo#PUT()` for sub-routes within the Group.
|
||||
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodPut, path, h, m...)
|
||||
}
|
||||
|
||||
// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
|
||||
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
|
||||
// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(http.MethodTrace, path, h, m...)
|
||||
}
|
||||
|
||||
// Any implements `Echo#Any()` for sub-routes within the Group.
|
||||
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
|
||||
routes := make([]*Route, len(methods))
|
||||
for i, m := range methods {
|
||||
routes[i] = g.Add(m, path, handler, middleware...)
|
||||
// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
|
||||
errs := make([]error, 0)
|
||||
ris := make(Routes, 0)
|
||||
for _, m := range methods {
|
||||
ri, err := g.AddRoute(Route{
|
||||
Method: m,
|
||||
Path: path,
|
||||
Handler: handler,
|
||||
Middlewares: middleware,
|
||||
})
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
ris = append(ris, ri)
|
||||
}
|
||||
return routes
|
||||
if len(errs) > 0 {
|
||||
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
|
||||
}
|
||||
return ris
|
||||
}
|
||||
|
||||
// Match implements `Echo#Match()` for sub-routes within the Group.
|
||||
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
|
||||
routes := make([]*Route, len(methods))
|
||||
for i, m := range methods {
|
||||
routes[i] = g.Add(m, path, handler, middleware...)
|
||||
// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
|
||||
errs := make([]error, 0)
|
||||
ris := make(Routes, 0)
|
||||
for _, m := range methods {
|
||||
ri, err := g.AddRoute(Route{
|
||||
Method: m,
|
||||
Path: path,
|
||||
Handler: handler,
|
||||
Middlewares: middleware,
|
||||
})
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
ris = append(ris, ri)
|
||||
}
|
||||
return routes
|
||||
if len(errs) > 0 {
|
||||
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
|
||||
}
|
||||
return ris
|
||||
}
|
||||
|
||||
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
|
||||
// Important! Group middlewares are only executed in case there was exact route match and not
|
||||
// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
|
||||
// a catch-all route `/*` for the group which handler returns always 404
|
||||
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
|
||||
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
|
||||
m = append(m, g.middleware...)
|
||||
@ -102,23 +124,43 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
|
||||
return
|
||||
}
|
||||
|
||||
// Static implements `Echo#Static()` for sub-routes within the Group.
|
||||
func (g *Group) Static(prefix, root string) {
|
||||
g.static(prefix, root, g.GET)
|
||||
// Static implements `Echo#Static()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(
|
||||
http.MethodGet,
|
||||
prefix+"*",
|
||||
StaticDirectoryHandler(root, false),
|
||||
middleware...,
|
||||
)
|
||||
}
|
||||
|
||||
// File implements `Echo#File()` for sub-routes within the Group.
|
||||
func (g *Group) File(path, file string) {
|
||||
g.file(path, file, g.GET)
|
||||
// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
|
||||
handler := func(c Context) error {
|
||||
return c.File(file)
|
||||
}
|
||||
return g.Add(http.MethodGet, path, handler, middleware...)
|
||||
}
|
||||
|
||||
// Add implements `Echo#Add()` for sub-routes within the Group.
|
||||
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
|
||||
// Combine into a new slice to avoid accidentally passing the same slice for
|
||||
// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
|
||||
ri, err := g.AddRoute(Route{
|
||||
Method: method,
|
||||
Path: path,
|
||||
Handler: handler,
|
||||
Middlewares: middleware,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
|
||||
}
|
||||
return ri
|
||||
}
|
||||
|
||||
// AddRoute registers a new Routable with Router
|
||||
func (g *Group) AddRoute(route Routable) (RouteInfo, error) {
|
||||
// Combine middleware into a new slice to avoid accidentally passing the same slice for
|
||||
// multiple routes, which would lead to later add() calls overwriting the
|
||||
// middleware from earlier calls.
|
||||
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
|
||||
m = append(m, g.middleware...)
|
||||
m = append(m, middleware...)
|
||||
return g.echo.add(g.host, method, g.prefix+path, handler, m...)
|
||||
groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
|
||||
return g.echo.add(g.host, groupRoute)
|
||||
}
|
||||
|
518
group_test.go
518
group_test.go
@ -1,31 +1,68 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TODO: Fix me
|
||||
func TestGroup(t *testing.T) {
|
||||
g := New().Group("/group")
|
||||
h := func(Context) error { return nil }
|
||||
g.CONNECT("/", h)
|
||||
g.DELETE("/", h)
|
||||
g.GET("/", h)
|
||||
g.HEAD("/", h)
|
||||
g.OPTIONS("/", h)
|
||||
g.PATCH("/", h)
|
||||
g.POST("/", h)
|
||||
g.PUT("/", h)
|
||||
g.TRACE("/", h)
|
||||
g.Any("/", h)
|
||||
g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
|
||||
g.Static("/static", "/tmp")
|
||||
g.File("/walle", "_fixture/images//walle.png")
|
||||
func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
called := false
|
||||
mw := func(next HandlerFunc) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
called = true
|
||||
return c.NoContent(http.StatusTeapot)
|
||||
}
|
||||
}
|
||||
// even though group has middleware it will not be executed when there are no routes under that group
|
||||
_ = e.Group("/group", mw)
|
||||
|
||||
status, body := request(http.MethodGet, "/group/nope", e)
|
||||
assert.Equal(t, http.StatusNotFound, status)
|
||||
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
|
||||
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
called := false
|
||||
mw := func(next HandlerFunc) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
called = true
|
||||
return c.NoContent(http.StatusTeapot)
|
||||
}
|
||||
}
|
||||
// even though group has middleware and routes when we have no match on some route the middlewares for that
|
||||
// group will not be executed
|
||||
g := e.Group("/group", mw)
|
||||
g.GET("/yes", handlerFunc)
|
||||
|
||||
status, body := request(http.MethodGet, "/group/nope", e)
|
||||
assert.Equal(t, http.StatusNotFound, status)
|
||||
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
|
||||
|
||||
assert.False(t, called)
|
||||
}
|
||||
|
||||
func TestGroup_multiLevelGroup(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
api := e.Group("/api")
|
||||
users := api.Group("/users")
|
||||
users.GET("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
status, body := request(http.MethodGet, "/api/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroupFile(t *testing.T) {
|
||||
@ -92,11 +129,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
|
||||
}
|
||||
m2 := func(next HandlerFunc) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
return c.String(http.StatusOK, c.Path())
|
||||
return c.String(http.StatusOK, c.RouteInfo().Path())
|
||||
}
|
||||
}
|
||||
h := func(c Context) error {
|
||||
return c.String(http.StatusOK, c.Path())
|
||||
return c.String(http.StatusOK, c.RouteInfo().Path())
|
||||
}
|
||||
g.Use(m1)
|
||||
g.GET("/help", h, m2)
|
||||
@ -119,3 +156,442 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
|
||||
assert.Equal(t, "/*", m)
|
||||
|
||||
}
|
||||
|
||||
func TestGroup_CONNECT(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.CONNECT("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodConnect, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodConnect, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_DELETE(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.DELETE("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodDelete, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodDelete, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_HEAD(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.HEAD("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodHead, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodHead+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodHead, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_OPTIONS(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.OPTIONS("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodOptions, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodOptions, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_PATCH(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.PATCH("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodPatch, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodPatch, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_POST(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.POST("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodPost, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodPost+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodPost, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_PUT(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.PUT("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodPut, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodPut+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodPut, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_TRACE(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ri := users.TRACE("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
|
||||
assert.Equal(t, http.MethodTrace, ri.Method())
|
||||
assert.Equal(t, "/users/activate", ri.Path())
|
||||
assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
|
||||
status, body := request(http.MethodTrace, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_Any(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
ris := users.Any("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
assert.Len(t, ris, 11)
|
||||
|
||||
for _, m := range methods {
|
||||
status, body := request(m, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_AnyWithErrors(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
users.GET("/activate", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
errs := func() (errs []error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if tmpErr, ok := r.([]error); ok {
|
||||
errs = tmpErr
|
||||
return
|
||||
}
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
users.Any("/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
return nil
|
||||
}()
|
||||
assert.Len(t, errs, 1)
|
||||
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
|
||||
|
||||
for _, m := range methods {
|
||||
status, body := request(m, "/users/activate", e)
|
||||
|
||||
expect := http.StatusTeapot
|
||||
if m == http.MethodGet {
|
||||
expect = http.StatusOK
|
||||
}
|
||||
assert.Equal(t, expect, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_Match(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
myMethods := []string{http.MethodGet, http.MethodPost}
|
||||
users := e.Group("/users")
|
||||
ris := users.Match(myMethods, "/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
assert.Len(t, ris, 2)
|
||||
|
||||
for _, m := range myMethods {
|
||||
status, body := request(m, "/users/activate", e)
|
||||
assert.Equal(t, http.StatusTeapot, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_MatchWithErrors(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
users := e.Group("/users")
|
||||
users.GET("/activate", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
myMethods := []string{http.MethodGet, http.MethodPost}
|
||||
|
||||
errs := func() (errs []error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if tmpErr, ok := r.([]error); ok {
|
||||
errs = tmpErr
|
||||
return
|
||||
}
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
users.Match(myMethods, "/activate", func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
})
|
||||
return nil
|
||||
}()
|
||||
assert.Len(t, errs, 1)
|
||||
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
|
||||
|
||||
for _, m := range myMethods {
|
||||
status, body := request(m, "/users/activate", e)
|
||||
|
||||
expect := http.StatusTeapot
|
||||
if m == http.MethodGet {
|
||||
expect = http.StatusOK
|
||||
}
|
||||
assert.Equal(t, expect, status)
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_Static(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
g := e.Group("/books")
|
||||
ri := g.Static("/download", "_fixture")
|
||||
assert.Equal(t, http.MethodGet, ri.Method())
|
||||
assert.Equal(t, "/books/download*", ri.Path())
|
||||
assert.Equal(t, "GET:/books/download*", ri.Name())
|
||||
assert.Equal(t, []string{"*"}, ri.Params())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
body := rec.Body.String()
|
||||
assert.True(t, strings.HasPrefix(body, "<!doctype html>"))
|
||||
}
|
||||
|
||||
func TestGroup_StaticMultiTest(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenPrefix string
|
||||
givenRoot string
|
||||
whenURL string
|
||||
expectStatus int
|
||||
expectHeaderLocation string
|
||||
expectBodyStartsWith string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
givenPrefix: "/images",
|
||||
givenRoot: "_fixture/images",
|
||||
whenURL: "/test/images/walle.png",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
|
||||
},
|
||||
{
|
||||
name: "ok, without prefix",
|
||||
givenPrefix: "",
|
||||
givenRoot: "_fixture/images",
|
||||
whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
|
||||
},
|
||||
{
|
||||
name: "nok, without prefix does not serve dir index",
|
||||
givenPrefix: "",
|
||||
givenRoot: "_fixture/images",
|
||||
whenURL: "/test/", // `/test` + `*` creates route `/test*`
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
{
|
||||
name: "No file",
|
||||
givenPrefix: "/images",
|
||||
givenRoot: "_fixture/scripts",
|
||||
whenURL: "/test/images/bolt.png",
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
{
|
||||
name: "Directory",
|
||||
givenPrefix: "/images",
|
||||
givenRoot: "_fixture/images",
|
||||
whenURL: "/test/images/",
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
{
|
||||
name: "Directory Redirect",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/folder",
|
||||
expectStatus: http.StatusMovedPermanently,
|
||||
expectHeaderLocation: "/test/folder/",
|
||||
expectBodyStartsWith: "",
|
||||
},
|
||||
{
|
||||
name: "Directory Redirect with non-root path",
|
||||
givenPrefix: "/static",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/static",
|
||||
expectStatus: http.StatusMovedPermanently,
|
||||
expectHeaderLocation: "/test/static/",
|
||||
expectBodyStartsWith: "",
|
||||
},
|
||||
{
|
||||
name: "Prefixed directory 404 (request URL without slash)",
|
||||
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/folder", // no trailing slash
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
{
|
||||
name: "Prefixed directory redirect (without slash redirect to slash)",
|
||||
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/folder", // no trailing slash
|
||||
expectStatus: http.StatusMovedPermanently,
|
||||
expectHeaderLocation: "/test/folder/",
|
||||
expectBodyStartsWith: "",
|
||||
},
|
||||
{
|
||||
name: "Directory with index.html",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
},
|
||||
{
|
||||
name: "Prefixed directory with index.html (prefix ending with slash)",
|
||||
givenPrefix: "/assets/",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/assets/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
},
|
||||
{
|
||||
name: "Prefixed directory with index.html (prefix ending without slash)",
|
||||
givenPrefix: "/assets",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/assets/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
},
|
||||
{
|
||||
name: "Sub-directory with index.html",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/test/folder/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
},
|
||||
{
|
||||
name: "do not allow directory traversal (backslash - windows separator)",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "_fixture/",
|
||||
whenURL: `/test/..\\middleware/basic_auth.go`,
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
{
|
||||
name: "do not allow directory traversal (slash - unix separator)",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "_fixture/",
|
||||
whenURL: `/test/../middleware/basic_auth.go`,
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
g := e.Group("/test")
|
||||
g.Static(tc.givenPrefix, tc.givenRoot)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectStatus, rec.Code)
|
||||
body := rec.Body.String()
|
||||
if tc.expectBodyStartsWith != "" {
|
||||
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
|
||||
} else {
|
||||
assert.Equal(t, "", body)
|
||||
}
|
||||
|
||||
if tc.expectHeaderLocation != "" {
|
||||
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
|
||||
} else {
|
||||
_, ok := rec.Result().Header["Location"]
|
||||
assert.False(t, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
74
httperror.go
Normal file
74
httperror.go
Normal file
@ -0,0 +1,74 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
|
||||
ErrNotFound = NewHTTPError(http.StatusNotFound)
|
||||
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
|
||||
ErrForbidden = NewHTTPError(http.StatusForbidden)
|
||||
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
|
||||
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
|
||||
ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
|
||||
ErrBadRequest = NewHTTPError(http.StatusBadRequest)
|
||||
ErrBadGateway = NewHTTPError(http.StatusBadGateway)
|
||||
ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
|
||||
ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
|
||||
ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
|
||||
ErrValidatorNotRegistered = errors.New("validator not registered")
|
||||
ErrRendererNotRegistered = errors.New("renderer not registered")
|
||||
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
|
||||
ErrCookieNotFound = errors.New("cookie not found")
|
||||
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
|
||||
ErrInvalidListenerNetwork = errors.New("invalid listener network")
|
||||
)
|
||||
|
||||
// HTTPError represents an error that occurred while handling a request.
|
||||
type HTTPError struct {
|
||||
Code int `json:"-"`
|
||||
Message interface{} `json:"message"`
|
||||
Internal error `json:"-"` // Stores the error returned by an external dependency
|
||||
}
|
||||
|
||||
// NewHTTPError creates a new HTTPError instance.
|
||||
func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used?
|
||||
he := &HTTPError{Code: code, Message: http.StatusText(code)}
|
||||
if len(message) > 0 {
|
||||
he.Message = message[0]
|
||||
}
|
||||
return he
|
||||
}
|
||||
|
||||
// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set.
|
||||
func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError {
|
||||
he := NewHTTPError(code, message...)
|
||||
he.Internal = internalError
|
||||
return he
|
||||
}
|
||||
|
||||
// Error makes it compatible with `error` interface.
|
||||
func (he *HTTPError) Error() string {
|
||||
if he.Internal == nil {
|
||||
return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
|
||||
}
|
||||
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
|
||||
}
|
||||
|
||||
// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
|
||||
func (he *HTTPError) WithInternal(err error) *HTTPError {
|
||||
return &HTTPError{
|
||||
Code: he.Code,
|
||||
Message: he.Message,
|
||||
Internal: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap satisfies the Go 1.13 error wrapper interface.
|
||||
func (he *HTTPError) Unwrap() error {
|
||||
return he.Internal
|
||||
}
|
52
httperror_test.go
Normal file
52
httperror_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPError(t *testing.T) {
|
||||
t.Run("non-internal", func(t *testing.T) {
|
||||
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||
"code": 12,
|
||||
})
|
||||
|
||||
assert.Equal(t, "code=400, message=map[code:12]", err.Error())
|
||||
})
|
||||
t.Run("internal", func(t *testing.T) {
|
||||
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||
"code": 12,
|
||||
})
|
||||
err = err.WithInternal(errors.New("internal error"))
|
||||
assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewHTTPErrorWithInternal(t *testing.T) {
|
||||
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message")
|
||||
assert.Equal(t, "code=400, message=test message, internal=test", he.Error())
|
||||
}
|
||||
|
||||
func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) {
|
||||
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"))
|
||||
assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error())
|
||||
}
|
||||
|
||||
func TestHTTPError_Unwrap(t *testing.T) {
|
||||
t.Run("non-internal", func(t *testing.T) {
|
||||
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||
"code": 12,
|
||||
})
|
||||
|
||||
assert.Nil(t, errors.Unwrap(err))
|
||||
})
|
||||
t.Run("internal", func(t *testing.T) {
|
||||
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
|
||||
"code": 12,
|
||||
})
|
||||
err = err.WithInternal(errors.New("internal error"))
|
||||
assert.Equal(t, "internal error", errors.Unwrap(err).Error())
|
||||
})
|
||||
}
|
11
json.go
11
json.go
@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string
|
||||
func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
|
||||
err := json.NewDecoder(c.Request().Body).Decode(i)
|
||||
if ute, ok := err.(*json.UnmarshalTypeError); ok {
|
||||
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(
|
||||
http.StatusBadRequest,
|
||||
err,
|
||||
fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset),
|
||||
)
|
||||
} else if se, ok := err.(*json.SyntaxError); ok {
|
||||
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
|
||||
return NewHTTPErrorWithInternal(http.StatusBadRequest,
|
||||
err,
|
||||
fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()),
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
168
log.go
168
log.go
@ -1,41 +1,141 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/labstack/gommon/log"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
// Logger defines the logging interface.
|
||||
Logger interface {
|
||||
Output() io.Writer
|
||||
SetOutput(w io.Writer)
|
||||
Prefix() string
|
||||
SetPrefix(p string)
|
||||
Level() log.Lvl
|
||||
SetLevel(v log.Lvl)
|
||||
SetHeader(h string)
|
||||
Print(i ...interface{})
|
||||
Printf(format string, args ...interface{})
|
||||
Printj(j log.JSON)
|
||||
Debug(i ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
Debugj(j log.JSON)
|
||||
Info(i ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Infoj(j log.JSON)
|
||||
Warn(i ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Warnj(j log.JSON)
|
||||
Error(i ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Errorj(j log.JSON)
|
||||
Fatal(i ...interface{})
|
||||
Fatalj(j log.JSON)
|
||||
Fatalf(format string, args ...interface{})
|
||||
Panic(i ...interface{})
|
||||
Panicj(j log.JSON)
|
||||
Panicf(format string, args ...interface{})
|
||||
//-----------------------------------------------------------------------------
|
||||
// Example for Zap (https://github.com/uber-go/zap)
|
||||
//func main() {
|
||||
// e := echo.New()
|
||||
// logger, _ := zap.NewProduction()
|
||||
// e.Logger = &ZapLogger{logger: logger}
|
||||
//}
|
||||
//type ZapLogger struct {
|
||||
// logger *zap.Logger
|
||||
//}
|
||||
//
|
||||
//func (l *ZapLogger) Write(p []byte) (n int, err error) {
|
||||
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
|
||||
// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message.
|
||||
// return len(p), nil
|
||||
//}
|
||||
//
|
||||
//func (l *ZapLogger) Error(err error) {
|
||||
// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo"))
|
||||
//}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Example for Zerolog (https://github.com/rs/zerolog)
|
||||
//func main() {
|
||||
// e := echo.New()
|
||||
// logger := zerolog.New(os.Stdout)
|
||||
// e.Logger = &ZeroLogger{logger: &logger}
|
||||
//}
|
||||
//
|
||||
//type ZeroLogger struct {
|
||||
// logger *zerolog.Logger
|
||||
//}
|
||||
//
|
||||
//func (l *ZeroLogger) Write(p []byte) (n int, err error) {
|
||||
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
|
||||
// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message.
|
||||
// return len(p), nil
|
||||
//}
|
||||
//
|
||||
//func (l *ZeroLogger) Error(err error) {
|
||||
// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error())
|
||||
//}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Example for Logrus (https://github.com/sirupsen/logrus)
|
||||
//func main() {
|
||||
// e := echo.New()
|
||||
// e.Logger = &LogrusLogger{logger: logrus.New()}
|
||||
//}
|
||||
//
|
||||
//type LogrusLogger struct {
|
||||
// logger *logrus.Logger
|
||||
//}
|
||||
//
|
||||
//func (l *LogrusLogger) Write(p []byte) (n int, err error) {
|
||||
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
|
||||
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message.
|
||||
// return len(p), nil
|
||||
//}
|
||||
//
|
||||
//func (l *LogrusLogger) Error(err error) {
|
||||
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err)
|
||||
//}
|
||||
|
||||
// Logger defines the logging interface that Echo uses internally in few places.
|
||||
// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice.
|
||||
type Logger interface {
|
||||
// Write provides writer interface for http.Server `ErrorLog` and for logging startup messages.
|
||||
// `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers,
|
||||
// and underlying FileSystem errors.
|
||||
// `logger` middleware will use this method to write its JSON payload.
|
||||
Write(p []byte) (n int, err error)
|
||||
// Error logs the error
|
||||
Error(err error)
|
||||
}
|
||||
|
||||
// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only
|
||||
// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting.
|
||||
// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases.
|
||||
type jsonLogger struct {
|
||||
writer io.Writer
|
||||
bufferPool sync.Pool
|
||||
|
||||
timeNow func() time.Time
|
||||
}
|
||||
|
||||
func newJSONLogger(writer io.Writer) *jsonLogger {
|
||||
return &jsonLogger{
|
||||
writer: writer,
|
||||
bufferPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewBuffer(make([]byte, 256))
|
||||
},
|
||||
},
|
||||
timeNow: time.Now,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Write(p []byte) (n int, err error) {
|
||||
pLen := len(p)
|
||||
if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message
|
||||
(p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') ||
|
||||
(p[0] == '{' && p[pLen-1] == '}') {
|
||||
return l.writer.Write(p)
|
||||
}
|
||||
// we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is
|
||||
// called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably
|
||||
// deserves at least WARN level.
|
||||
return l.printf("WARN", string(p))
|
||||
}
|
||||
|
||||
func (l *jsonLogger) Error(err error) {
|
||||
_, _ = l.printf("ERROR", err.Error())
|
||||
}
|
||||
|
||||
func (l *jsonLogger) printf(level string, message string) (n int, err error) {
|
||||
buf := l.bufferPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer l.bufferPool.Put(buf)
|
||||
|
||||
buf.WriteString(`{"time":"`)
|
||||
buf.WriteString(l.timeNow().Format(time.RFC3339Nano))
|
||||
buf.WriteString(`","level":"`)
|
||||
buf.WriteString(level)
|
||||
buf.WriteString(`","prefix":"echo","message":`)
|
||||
|
||||
buf.WriteString(strconv.Quote(message))
|
||||
buf.WriteString("}\n")
|
||||
|
||||
return l.writer.Write(buf.Bytes())
|
||||
}
|
||||
|
77
log_test.go
Normal file
77
log_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestJsonLogger_Write(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
when []byte
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "ok, write non JSONlike message",
|
||||
when: []byte("version: %v, build: %v"),
|
||||
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: %v, build: %v"}` + "\n",
|
||||
},
|
||||
{
|
||||
name: "ok, write quoted message",
|
||||
when: []byte(`version: "%v"`),
|
||||
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: \"%v\""}` + "\n",
|
||||
},
|
||||
{
|
||||
name: "ok, write JSON",
|
||||
when: []byte(`{"version": 123}` + "\n"),
|
||||
expect: `{"version": 123}` + "\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
logger := newJSONLogger(buf)
|
||||
logger.timeNow = func() time.Time {
|
||||
return time.Unix(1631045377, 0)
|
||||
}
|
||||
|
||||
_, err := logger.Write(tc.when)
|
||||
|
||||
result := buf.String()
|
||||
assert.Equal(t, tc.expect, result)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonLogger_Error(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenError error
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
whenError: ErrForbidden,
|
||||
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
logger := newJSONLogger(buf)
|
||||
logger.timeNow = func() time.Time {
|
||||
return time.Unix(1631045377, 0)
|
||||
}
|
||||
|
||||
logger.Error(tc.whenError)
|
||||
|
||||
result := buf.String()
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
13
middleware/DEVELOPMENT.md
Normal file
13
middleware/DEVELOPMENT.md
Normal file
@ -0,0 +1,13 @@
|
||||
# Development Guidelines for middlewares
|
||||
|
||||
// FIXME: add info about `MiddlewareConfigurator` interface
|
||||
|
||||
## Best practices:
|
||||
|
||||
* Do not use `panic` in middleware creator functions in case of invalid configuration.
|
||||
* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
|
||||
because previous middlewares up in call chain could have logic for dealing with returned errors.
|
||||
* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
|
||||
want to create middleware with panics or with returning errors on configuration errors.
|
||||
* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.
|
||||
|
@ -1,64 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// BasicAuthConfig defines the config for BasicAuth middleware.
|
||||
BasicAuthConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
|
||||
type BasicAuthConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Validator is a function to validate BasicAuth credentials.
|
||||
// Required.
|
||||
Validator BasicAuthValidator
|
||||
// Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
|
||||
// this function would be called once for each header until first valid result is returned
|
||||
// Required.
|
||||
Validator BasicAuthValidator
|
||||
|
||||
// Realm is a string to define realm attribute of BasicAuth.
|
||||
// Default value "Restricted".
|
||||
Realm string
|
||||
}
|
||||
// Realm is a string to define realm attribute of BasicAuthWithConfig.
|
||||
// Default value "Restricted".
|
||||
Realm string
|
||||
}
|
||||
|
||||
// BasicAuthValidator defines a function to validate BasicAuth credentials.
|
||||
BasicAuthValidator func(string, string, echo.Context) (bool, error)
|
||||
)
|
||||
// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
|
||||
type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error)
|
||||
|
||||
const (
|
||||
basic = "basic"
|
||||
defaultRealm = "Restricted"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
|
||||
DefaultBasicAuthConfig = BasicAuthConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Realm: defaultRealm,
|
||||
}
|
||||
)
|
||||
|
||||
// BasicAuth returns an BasicAuth middleware.
|
||||
//
|
||||
// For valid credentials it calls the next handler.
|
||||
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
|
||||
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
|
||||
c := DefaultBasicAuthConfig
|
||||
c.Validator = fn
|
||||
return BasicAuthWithConfig(c)
|
||||
return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
|
||||
}
|
||||
|
||||
// BasicAuthWithConfig returns an BasicAuth middleware with config.
|
||||
// See `BasicAuth()`.
|
||||
// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
|
||||
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
|
||||
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Validator == nil {
|
||||
panic("echo: basic-auth middleware requires a validator function")
|
||||
return nil, errors.New("echo basic-auth middleware requires a validator function")
|
||||
}
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultBasicAuthConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.Realm == "" {
|
||||
config.Realm = defaultRealm
|
||||
@ -70,29 +65,33 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
auth := c.Request().Header.Get(echo.HeaderAuthorization)
|
||||
var lastError error
|
||||
l := len(basic)
|
||||
|
||||
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err != nil {
|
||||
return err
|
||||
for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
|
||||
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
|
||||
continue
|
||||
}
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
valid, err := config.Validator(cred[:i], cred[i+1:], c)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if valid {
|
||||
return next(c)
|
||||
}
|
||||
break
|
||||
|
||||
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if errDecode != nil {
|
||||
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
|
||||
continue
|
||||
}
|
||||
idx := bytes.IndexByte(b, ':')
|
||||
if idx >= 0 {
|
||||
valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
|
||||
if errValidate != nil {
|
||||
lastError = errValidate
|
||||
} else if valid {
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastError != nil {
|
||||
return lastError
|
||||
}
|
||||
|
||||
realm := defaultRealm
|
||||
if config.Realm != defaultRealm {
|
||||
realm = strconv.Quote(config.Realm)
|
||||
@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
|
||||
return echo.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -12,60 +13,146 @@ import (
|
||||
)
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
f := func(u, p string, c echo.Context) (bool, error) {
|
||||
validatorFunc := func(c echo.Context, u, p string) (bool, error) {
|
||||
if u == "joe" && p == "secret" {
|
||||
return true, nil
|
||||
}
|
||||
if u == "error" {
|
||||
return false, errors.New(p)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
h := BasicAuth(f)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
|
||||
|
||||
assert := assert.New(t)
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenConfig BasicAuthConfig
|
||||
whenAuth []string
|
||||
expectHeader string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{
|
||||
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
|
||||
basic + " NOT_BASE64",
|
||||
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
|
||||
expectHeader: basic + ` realm=Restricted`,
|
||||
expectErr: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "nok, not base64 Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
|
||||
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
|
||||
},
|
||||
{
|
||||
name: "nok, missing Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
expectHeader: basic + ` realm=Restricted`,
|
||||
expectErr: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "ok, realm",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
|
||||
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
|
||||
},
|
||||
{
|
||||
name: "ok, realm, case-insensitive header scheme",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
|
||||
},
|
||||
{
|
||||
name: "nok, realm, invalid Authorization header",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
|
||||
expectHeader: basic + ` realm="someRealm"`,
|
||||
expectErr: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "nok, validator func returns an error",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
|
||||
expectErr: "my_error",
|
||||
},
|
||||
{
|
||||
name: "ok, skipped",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
}},
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
|
||||
},
|
||||
}
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
||||
Skipper: nil,
|
||||
Validator: f,
|
||||
Realm: "someRealm",
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
config := tc.givenConfig
|
||||
|
||||
// Valid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
mw, err := config.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Case-insensitive header scheme
|
||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
h := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusTeapot, "test")
|
||||
})
|
||||
|
||||
// Invalid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
if len(tc.whenAuth) != 0 {
|
||||
for _, a := range tc.whenAuth {
|
||||
req.Header.Add(echo.HeaderAuthorization, a)
|
||||
}
|
||||
}
|
||||
err = h(c)
|
||||
|
||||
// Missing Authorization header
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
if tc.expectErr != "" {
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.EqualError(t, err, tc.expectErr)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tc.expectHeader != "" {
|
||||
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuth_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mw := BasicAuth(nil)
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
|
||||
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
assert.NotNil(t, mw)
|
||||
}
|
||||
|
||||
func TestBasicAuthWithConfig_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
|
||||
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
|
||||
return true, nil
|
||||
}})
|
||||
assert.NotNil(t, mw)
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@ -11,63 +12,56 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// BodyDumpConfig defines the config for BodyDump middleware.
|
||||
BodyDumpConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// BodyDumpConfig defines the config for BodyDump middleware.
|
||||
type BodyDumpConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Handler receives request and response payload.
|
||||
// Required.
|
||||
Handler BodyDumpHandler
|
||||
}
|
||||
// Handler receives request and response payload.
|
||||
// Required.
|
||||
Handler BodyDumpHandler
|
||||
}
|
||||
|
||||
// BodyDumpHandler receives the request and response payload.
|
||||
BodyDumpHandler func(echo.Context, []byte, []byte)
|
||||
// BodyDumpHandler receives the request and response payload.
|
||||
type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte)
|
||||
|
||||
bodyDumpResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBodyDumpConfig is the default BodyDump middleware config.
|
||||
DefaultBodyDumpConfig = BodyDumpConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
}
|
||||
)
|
||||
type bodyDumpResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// BodyDump returns a BodyDump middleware.
|
||||
//
|
||||
// BodyDump middleware captures the request and response payload and calls the
|
||||
// registered handler.
|
||||
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
|
||||
c := DefaultBodyDumpConfig
|
||||
c.Handler = handler
|
||||
return BodyDumpWithConfig(c)
|
||||
return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
|
||||
}
|
||||
|
||||
// BodyDumpWithConfig returns a BodyDump middleware with config.
|
||||
// See: `BodyDump()`.
|
||||
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
|
||||
func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Handler == nil {
|
||||
panic("echo: body-dump middleware requires a handler function")
|
||||
return nil, errors.New("echo body-dump middleware requires a handler function")
|
||||
}
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultBodyDumpConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Request
|
||||
reqBody := []byte{}
|
||||
if c.Request().Body != nil { // Read
|
||||
if c.Request().Body != nil {
|
||||
reqBody, _ = ioutil.ReadAll(c.Request().Body)
|
||||
}
|
||||
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
|
||||
@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
|
||||
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
|
||||
c.Response().Writer = writer
|
||||
|
||||
if err = next(c); err != nil {
|
||||
c.Error(err)
|
||||
}
|
||||
err := next(c)
|
||||
|
||||
// Callback
|
||||
config.Handler(c, reqBody, resBody.Bytes())
|
||||
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *bodyDumpResponseWriter) WriteHeader(code int) {
|
||||
|
@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) {
|
||||
|
||||
requestBody := ""
|
||||
responseBody := ""
|
||||
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
|
||||
mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||
requestBody = string(reqBody)
|
||||
responseBody = string(resBody)
|
||||
})
|
||||
}}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
if assert.NoError(mw(h)(c)) {
|
||||
assert.Equal(requestBody, hw)
|
||||
assert.Equal(responseBody, hw)
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(hw, rec.Body.String())
|
||||
if assert.NoError(t, mw(h)(c)) {
|
||||
assert.Equal(t, requestBody, hw)
|
||||
assert.Equal(t, responseBody, hw)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, hw, rec.Body.String())
|
||||
}
|
||||
|
||||
// Must set default skipper
|
||||
BodyDumpWithConfig(BodyDumpConfig{
|
||||
Skipper: nil,
|
||||
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||
requestBody = string(reqBody)
|
||||
responseBody = string(resBody)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBodyDumpFails(t *testing.T) {
|
||||
func TestBodyDump_skipper(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
isCalled := false
|
||||
mw, err := BodyDumpConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||
isCalled = true
|
||||
},
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := func(c echo.Context) error {
|
||||
return errors.New("some error")
|
||||
}
|
||||
|
||||
err = mw(h)(c)
|
||||
assert.EqualError(t, err, "some error")
|
||||
assert.False(t, isCalled)
|
||||
}
|
||||
|
||||
func TestBodyDump_fails(t *testing.T) {
|
||||
e := echo.New()
|
||||
hw := "Hello, World!"
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
|
||||
@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) {
|
||||
return errors.New("some error")
|
||||
}
|
||||
|
||||
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
|
||||
mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !assert.Error(t, mw(h)(c)) {
|
||||
t.FailNow()
|
||||
}
|
||||
err = mw(h)(c)
|
||||
assert.EqualError(t, err, "some error")
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
}
|
||||
|
||||
func TestBodyDumpWithConfig_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mw = BodyDumpWithConfig(BodyDumpConfig{
|
||||
mw := BodyDumpWithConfig(BodyDumpConfig{
|
||||
Skipper: nil,
|
||||
Handler: nil,
|
||||
})
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
mw = BodyDumpWithConfig(BodyDumpConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||
},
|
||||
})
|
||||
|
||||
if !assert.Error(t, mw(h)(c)) {
|
||||
t.FailNow()
|
||||
}
|
||||
mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}})
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBodyDump_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mw := BodyDump(nil)
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
|
||||
})
|
||||
}
|
||||
|
@ -1,98 +1,83 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/bytes"
|
||||
)
|
||||
|
||||
type (
|
||||
// BodyLimitConfig defines the config for BodyLimit middleware.
|
||||
BodyLimitConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
|
||||
type BodyLimitConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Maximum allowed size for a request body, it can be specified
|
||||
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
|
||||
Limit string `yaml:"limit"`
|
||||
limit int64
|
||||
}
|
||||
// LimitBytes is maximum allowed size in bytes for a request body
|
||||
LimitBytes int64
|
||||
}
|
||||
|
||||
limitedReader struct {
|
||||
BodyLimitConfig
|
||||
reader io.ReadCloser
|
||||
read int64
|
||||
context echo.Context
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBodyLimitConfig is the default BodyLimit middleware config.
|
||||
DefaultBodyLimitConfig = BodyLimitConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
}
|
||||
)
|
||||
type limitedReader struct {
|
||||
BodyLimitConfig
|
||||
reader io.ReadCloser
|
||||
read int64
|
||||
context echo.Context
|
||||
}
|
||||
|
||||
// BodyLimit returns a BodyLimit middleware.
|
||||
//
|
||||
// BodyLimit middleware sets the maximum allowed size for a request body, if the
|
||||
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
|
||||
// response. The BodyLimit is determined based on both `Content-Length` request
|
||||
// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
|
||||
// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
|
||||
// header and actual content read, which makes it super secure.
|
||||
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
|
||||
// G, T or P.
|
||||
func BodyLimit(limit string) echo.MiddlewareFunc {
|
||||
c := DefaultBodyLimitConfig
|
||||
c.Limit = limit
|
||||
return BodyLimitWithConfig(c)
|
||||
func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
|
||||
return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
|
||||
}
|
||||
|
||||
// BodyLimitWithConfig returns a BodyLimit middleware with config.
|
||||
// See: `BodyLimit()`.
|
||||
// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
|
||||
// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
|
||||
// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
|
||||
// makes it super secure.
|
||||
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultBodyLimitConfig.Skipper
|
||||
}
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
limit, err := bytes.Parse(config.Limit)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
|
||||
// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
|
||||
func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
pool := sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &limitedReader{BodyLimitConfig: config}
|
||||
},
|
||||
}
|
||||
config.limit = limit
|
||||
pool := limitedReaderPool(config)
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
req := c.Request()
|
||||
|
||||
// Based on content length
|
||||
if req.ContentLength > config.limit {
|
||||
if req.ContentLength > config.LimitBytes {
|
||||
return echo.ErrStatusRequestEntityTooLarge
|
||||
}
|
||||
|
||||
// Based on content read
|
||||
r := pool.Get().(*limitedReader)
|
||||
r.Reset(req.Body, c)
|
||||
r.Reset(c, req.Body)
|
||||
defer pool.Put(r)
|
||||
req.Body = r
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *limitedReader) Read(b []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(b)
|
||||
r.read += int64(n)
|
||||
if r.read > r.limit {
|
||||
if r.read > r.LimitBytes {
|
||||
return n, echo.ErrStatusRequestEntityTooLarge
|
||||
}
|
||||
return
|
||||
@ -102,16 +87,8 @@ func (r *limitedReader) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
|
||||
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
|
||||
func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
|
||||
r.reader = reader
|
||||
r.context = context
|
||||
r.read = 0
|
||||
}
|
||||
|
||||
func limitedReaderPool(c BodyLimitConfig) sync.Pool {
|
||||
return sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &limitedReader{BodyLimitConfig: c}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,137 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
|
||||
e := echo.New()
|
||||
hw := []byte("Hello, World!")
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := func(c echo.Context) error {
|
||||
body, err := ioutil.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.String(http.StatusOK, string(body))
|
||||
}
|
||||
|
||||
// Based on content length (within limit)
|
||||
mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = mw(h)(c)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, hw, rec.Body.Bytes())
|
||||
}
|
||||
|
||||
// Based on content read (overlimit)
|
||||
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
he := mw(h)(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||
|
||||
// Based on content read (within limit)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
|
||||
mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
err = mw(h)(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "Hello, World!", rec.Body.String())
|
||||
|
||||
// Based on content read (overlimit)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
he = mw(h)(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||
}
|
||||
|
||||
func TestBodyLimitReader(t *testing.T) {
|
||||
hw := []byte("Hello, World!")
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
config := BodyLimitConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
LimitBytes: 2,
|
||||
}
|
||||
reader := &limitedReader{
|
||||
BodyLimitConfig: config,
|
||||
reader: ioutil.NopCloser(bytes.NewReader(hw)),
|
||||
context: e.NewContext(req, rec),
|
||||
}
|
||||
|
||||
// read all should return ErrStatusRequestEntityTooLarge
|
||||
_, err := ioutil.ReadAll(reader)
|
||||
he := err.(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||
|
||||
// reset reader and read two bytes must succeed
|
||||
bt := make([]byte, 2)
|
||||
reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw)))
|
||||
n, err := reader.Read(bt)
|
||||
assert.Equal(t, 2, n)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
||||
func TestBodyLimit_skipper(t *testing.T) {
|
||||
e := echo.New()
|
||||
h := func(c echo.Context) error {
|
||||
body, err := ioutil.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.String(http.StatusOK, string(body))
|
||||
}
|
||||
mw, err := BodyLimitConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
LimitBytes: 2,
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
hw := []byte("Hello, World!")
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err = mw(h)(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, hw, rec.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestBodyLimitWithConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
hw := []byte("Hello, World!")
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := func(c echo.Context) error {
|
||||
body, err := ioutil.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.String(http.StatusOK, string(body))
|
||||
}
|
||||
|
||||
mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
|
||||
|
||||
err := mw(h)(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, hw, rec.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestBodyLimit(t *testing.T) {
|
||||
e := echo.New()
|
||||
hw := []byte("Hello, World!")
|
||||
@ -25,61 +156,10 @@ func TestBodyLimit(t *testing.T) {
|
||||
return c.String(http.StatusOK, string(body))
|
||||
}
|
||||
|
||||
assert := assert.New(t)
|
||||
mw := BodyLimit(2 * MB)
|
||||
|
||||
// Based on content length (within limit)
|
||||
if assert.NoError(BodyLimit("2M")(h)(c)) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal(hw, rec.Body.Bytes())
|
||||
}
|
||||
|
||||
// Based on content read (overlimit)
|
||||
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
|
||||
|
||||
// Based on content read (within limit)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
if assert.NoError(BodyLimit("2M")(h)(c)) {
|
||||
assert.Equal(http.StatusOK, rec.Code)
|
||||
assert.Equal("Hello, World!", rec.Body.String())
|
||||
}
|
||||
|
||||
// Based on content read (overlimit)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
|
||||
}
|
||||
|
||||
func TestBodyLimitReader(t *testing.T) {
|
||||
hw := []byte("Hello, World!")
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
config := BodyLimitConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Limit: "2B",
|
||||
limit: 2,
|
||||
}
|
||||
reader := &limitedReader{
|
||||
BodyLimitConfig: config,
|
||||
reader: ioutil.NopCloser(bytes.NewReader(hw)),
|
||||
context: e.NewContext(req, rec),
|
||||
}
|
||||
|
||||
// read all should return ErrStatusRequestEntityTooLarge
|
||||
_, err := ioutil.ReadAll(reader)
|
||||
he := err.(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
|
||||
|
||||
// reset reader and read two bytes must succeed
|
||||
bt := make([]byte, 2)
|
||||
reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
|
||||
n, err := reader.Read(bt)
|
||||
assert.Equal(t, 2, n)
|
||||
assert.Equal(t, nil, err)
|
||||
err := mw(h)(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, hw, rec.Body.Bytes())
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@ -13,50 +14,45 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// GzipConfig defines the config for Gzip middleware.
|
||||
GzipConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Gzip compression level.
|
||||
// Optional. Default value -1.
|
||||
Level int `yaml:"level"`
|
||||
}
|
||||
|
||||
gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
gzipScheme = "gzip"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultGzipConfig is the default Gzip middleware config.
|
||||
DefaultGzipConfig = GzipConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Level: -1,
|
||||
}
|
||||
)
|
||||
// GzipConfig defines the config for Gzip middleware.
|
||||
type GzipConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using gzip compression
|
||||
// scheme.
|
||||
func Gzip() echo.MiddlewareFunc {
|
||||
return GzipWithConfig(DefaultGzipConfig)
|
||||
// Gzip compression level.
|
||||
// Optional. Default value -1.
|
||||
Level int
|
||||
}
|
||||
|
||||
// GzipWithConfig return Gzip middleware with config.
|
||||
// See: `Gzip()`.
|
||||
type gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
func Gzip() echo.MiddlewareFunc {
|
||||
return GzipWithConfig(GzipConfig{})
|
||||
}
|
||||
|
||||
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
|
||||
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
|
||||
func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultGzipConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
|
||||
return nil, errors.New("invalid gzip level")
|
||||
}
|
||||
if config.Level == 0 {
|
||||
config.Level = DefaultGzipConfig.Level
|
||||
config.Level = -1
|
||||
}
|
||||
|
||||
pool := gzipCompressPool(config)
|
||||
@ -97,7 +93,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteHeader(code int) {
|
||||
|
@ -3,94 +3,128 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGzip(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
|
||||
// Skip if no Accept-Encoding header
|
||||
h := Gzip()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
h(c)
|
||||
|
||||
assert := assert.New(t)
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
assert.Equal("test", rec.Body.String())
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Gzip
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
assert.Equal(t, "test", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestMustGzipWithConfig_panics(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
GzipWithConfig(GzipConfig{Level: 999})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGzip_AcceptEncodingHeader(t *testing.T) {
|
||||
h := Gzip()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h(c)
|
||||
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
|
||||
|
||||
r, err := gzip.NewReader(rec.Body)
|
||||
if assert.NoError(err) {
|
||||
buf := new(bytes.Buffer)
|
||||
defer r.Close()
|
||||
buf.ReadFrom(r)
|
||||
assert.Equal("test", buf.String())
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
buf := new(bytes.Buffer)
|
||||
defer r.Close()
|
||||
buf.ReadFrom(r)
|
||||
assert.Equal(t, "test", buf.String())
|
||||
}
|
||||
|
||||
chunkBuf := make([]byte, 5)
|
||||
|
||||
// Gzip chunked
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
func TestGzip_chunked(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec = httptest.NewRecorder()
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c = e.NewContext(req, rec)
|
||||
Gzip()(func(c echo.Context) error {
|
||||
chunkChan := make(chan struct{})
|
||||
waitChan := make(chan struct{})
|
||||
h := Gzip()(func(c echo.Context) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Transfer-Encoding", "chunked")
|
||||
|
||||
// Write and flush the first part of the data
|
||||
c.Response().Write([]byte("test\n"))
|
||||
c.Response().Write([]byte("first\n"))
|
||||
c.Response().Flush()
|
||||
|
||||
// Read the first part of the data
|
||||
assert.True(rec.Flushed)
|
||||
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
r.Reset(rec.Body)
|
||||
|
||||
_, err = io.ReadFull(r, chunkBuf)
|
||||
assert.NoError(err)
|
||||
assert.Equal("test\n", string(chunkBuf))
|
||||
chunkChan <- struct{}{}
|
||||
<-waitChan
|
||||
|
||||
// Write and flush the second part of the data
|
||||
c.Response().Write([]byte("test\n"))
|
||||
c.Response().Write([]byte("second\n"))
|
||||
c.Response().Flush()
|
||||
|
||||
_, err = io.ReadFull(r, chunkBuf)
|
||||
assert.NoError(err)
|
||||
assert.Equal("test\n", string(chunkBuf))
|
||||
chunkChan <- struct{}{}
|
||||
<-waitChan
|
||||
|
||||
// Write the final part of the data and return
|
||||
c.Response().Write([]byte("test"))
|
||||
return nil
|
||||
})(c)
|
||||
c.Response().Write([]byte("third"))
|
||||
|
||||
chunkChan <- struct{}{}
|
||||
return nil
|
||||
})
|
||||
|
||||
go func() {
|
||||
err := h(c)
|
||||
chunkChan <- struct{}{}
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
<-chunkChan // wait for first write
|
||||
waitChan <- struct{}{}
|
||||
|
||||
<-chunkChan // wait for second write
|
||||
waitChan <- struct{}{}
|
||||
|
||||
<-chunkChan // wait for final write in handler
|
||||
<-chunkChan // wait for return from handler
|
||||
time.Sleep(5 * time.Millisecond) // to have time for flushing
|
||||
|
||||
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
|
||||
r, err := gzip.NewReader(rec.Body)
|
||||
assert.NoError(t, err)
|
||||
buf := new(bytes.Buffer)
|
||||
defer r.Close()
|
||||
buf.ReadFrom(r)
|
||||
assert.Equal("test", buf.String())
|
||||
assert.Equal(t, "first\nsecond\nthird", buf.String())
|
||||
}
|
||||
|
||||
func TestGzipNoContent(t *testing.T) {
|
||||
func TestGzip_NoContent(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipErrorReturned(t *testing.T) {
|
||||
func TestGzip_ErrorReturned(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(Gzip())
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
@ -120,31 +154,25 @@ func TestGzipErrorReturned(t *testing.T) {
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
}
|
||||
|
||||
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
// Invalid level
|
||||
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test"))
|
||||
return nil
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "gzip")
|
||||
func TestGzipWithConfig_invalidLevel(t *testing.T) {
|
||||
mw, err := GzipConfig{Level: 12}.ToMiddleware()
|
||||
assert.EqualError(t, err, "invalid gzip level")
|
||||
assert.Nil(t, mw)
|
||||
}
|
||||
|
||||
// Issue #806
|
||||
func TestGzipWithStatic(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Filesystem = os.DirFS("../")
|
||||
|
||||
e.Use(Gzip())
|
||||
e.Static("/test", "../_fixture/images")
|
||||
e.Static("/test", "_fixture/images")
|
||||
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
|
||||
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
// Data is written out in chunks when Content-Length == "", so only
|
||||
// validate the content length if it's not set.
|
||||
|
@ -9,60 +9,56 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// CORSConfig defines the config for CORS middleware.
|
||||
CORSConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// CORSConfig defines the config for CORS middleware.
|
||||
type CORSConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// AllowOrigin defines a list of origins that may access the resource.
|
||||
// Optional. Default value []string{"*"}.
|
||||
AllowOrigins []string `yaml:"allow_origins"`
|
||||
// AllowOrigin defines a list of origins that may access the resource.
|
||||
// Optional. Default value []string{"*"}.
|
||||
AllowOrigins []string
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error)
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
AllowMethods []string `yaml:"allow_methods"`
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
AllowMethods []string
|
||||
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
// making the actual request. This is in response to a preflight request.
|
||||
// Optional. Default value []string{}.
|
||||
AllowHeaders []string `yaml:"allow_headers"`
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
// making the actual request. This is in response to a preflight request.
|
||||
// Optional. Default value []string{}.
|
||||
AllowHeaders []string
|
||||
|
||||
// AllowCredentials indicates whether or not the response to the request
|
||||
// can be exposed when the credentials flag is true. When used as part of
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials.
|
||||
// Optional. Default value false.
|
||||
AllowCredentials bool `yaml:"allow_credentials"`
|
||||
// AllowCredentials indicates whether or not the response to the request
|
||||
// can be exposed when the credentials flag is true. When used as part of
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials.
|
||||
// Optional. Default value false.
|
||||
AllowCredentials bool
|
||||
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
// access.
|
||||
// Optional. Default value []string{}.
|
||||
ExposeHeaders []string `yaml:"expose_headers"`
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
// access.
|
||||
// Optional. Default value []string{}.
|
||||
ExposeHeaders []string
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached.
|
||||
// Optional. Default value 0.
|
||||
MaxAge int `yaml:"max_age"`
|
||||
}
|
||||
)
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached.
|
||||
// Optional. Default value 0.
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultCORSConfig is the default CORS middleware config.
|
||||
DefaultCORSConfig = CORSConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}
|
||||
)
|
||||
// DefaultCORSConfig is the default CORS middleware config.
|
||||
var DefaultCORSConfig = CORSConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
||||
}
|
||||
|
||||
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
|
||||
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
|
||||
@ -70,9 +66,14 @@ func CORS() echo.MiddlewareFunc {
|
||||
return CORSWithConfig(DefaultCORSConfig)
|
||||
}
|
||||
|
||||
// CORSWithConfig returns a CORS middleware with config.
|
||||
// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
|
||||
// See: `CORS()`.
|
||||
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
|
||||
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultCORSConfig.Skipper
|
||||
@ -207,5 +208,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ func TestCORS(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := CORS()(echo.NotFoundHandler)
|
||||
h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||
h(c)
|
||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
@ -26,7 +26,7 @@ func TestCORS(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h = CORS()(echo.NotFoundHandler)
|
||||
h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
|
||||
@ -38,7 +38,7 @@ func TestCORS(t *testing.T) {
|
||||
AllowOrigins: []string{"localhost"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
})(echo.NotFoundHandler)
|
||||
})(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||
h(c)
|
||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
@ -55,7 +55,7 @@ func TestCORS(t *testing.T) {
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
})
|
||||
h = cors(echo.NotFoundHandler)
|
||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
@ -73,7 +73,7 @@ func TestCORS(t *testing.T) {
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
})
|
||||
h = cors(echo.NotFoundHandler)
|
||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
@ -90,7 +90,7 @@ func TestCORS(t *testing.T) {
|
||||
cors = CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
})
|
||||
h = cors(echo.NotFoundHandler)
|
||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
|
||||
@ -104,7 +104,7 @@ func TestCORS(t *testing.T) {
|
||||
cors = CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{"http://*.example.com"},
|
||||
})
|
||||
h = cors(echo.NotFoundHandler)
|
||||
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) {
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.pattern},
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
|
||||
if tt.expected {
|
||||
@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) {
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.pattern},
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
|
||||
if tt.expected {
|
||||
@ -331,7 +331,7 @@ func TestCorsHeaders(t *testing.T) {
|
||||
//AllowCredentials: true,
|
||||
//MaxAge: 3600,
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
h(c)
|
||||
|
||||
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||
@ -387,11 +387,11 @@ func Test_allowOriginFunc(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, origin)
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOriginFunc: allowOriginFunc,
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
err := h(c)
|
||||
cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
|
||||
err = h(c)
|
||||
|
||||
expected, expectedErr := allowOriginFunc(origin)
|
||||
if expectedErr != nil {
|
||||
|
@ -8,89 +8,90 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
)
|
||||
|
||||
type (
|
||||
// CSRFConfig defines the config for CSRF middleware.
|
||||
CSRFConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// CSRFConfig defines the config for CSRF middleware.
|
||||
type CSRFConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// TokenLength is the length of the generated token.
|
||||
TokenLength uint8 `yaml:"token_length"`
|
||||
// Optional. Default value 32.
|
||||
// TokenLength is the length of the generated token.
|
||||
TokenLength uint8
|
||||
// Optional. Default value 32.
|
||||
|
||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to extract token from the request.
|
||||
// Optional. Default value "header:X-CSRF-Token".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "form:<name>"
|
||||
// - "query:<name>"
|
||||
TokenLookup string `yaml:"token_lookup"`
|
||||
// Generator defines a function to generate token.
|
||||
// Optional. Defaults tp randomString(TokenLength).
|
||||
Generator func() string
|
||||
|
||||
// Context key to store generated CSRF token into context.
|
||||
// Optional. Default value "csrf".
|
||||
ContextKey string `yaml:"context_key"`
|
||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||
// to extract token from the request.
|
||||
// Optional. Default value "header:X-CSRF-Token".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "form:<name>"
|
||||
// - "query:<name>"
|
||||
TokenLookup string
|
||||
|
||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||
// Optional. Default value "csrf".
|
||||
CookieName string `yaml:"cookie_name"`
|
||||
// Context key to store generated CSRF token into context.
|
||||
// Optional. Default value "csrf".
|
||||
ContextKey string
|
||||
|
||||
// Domain of the CSRF cookie.
|
||||
// Optional. Default value none.
|
||||
CookieDomain string `yaml:"cookie_domain"`
|
||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||
// Optional. Default value "csrf".
|
||||
CookieName string
|
||||
|
||||
// Path of the CSRF cookie.
|
||||
// Optional. Default value none.
|
||||
CookiePath string `yaml:"cookie_path"`
|
||||
// Domain of the CSRF cookie.
|
||||
// Optional. Default value none.
|
||||
CookieDomain string
|
||||
|
||||
// Max age (in seconds) of the CSRF cookie.
|
||||
// Optional. Default value 86400 (24hr).
|
||||
CookieMaxAge int `yaml:"cookie_max_age"`
|
||||
// Path of the CSRF cookie.
|
||||
// Optional. Default value none.
|
||||
CookiePath string
|
||||
|
||||
// Indicates if CSRF cookie is secure.
|
||||
// Optional. Default value false.
|
||||
CookieSecure bool `yaml:"cookie_secure"`
|
||||
// Max age (in seconds) of the CSRF cookie.
|
||||
// Optional. Default value 86400 (24hr).
|
||||
CookieMaxAge int
|
||||
|
||||
// Indicates if CSRF cookie is HTTP only.
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool `yaml:"cookie_http_only"`
|
||||
// Indicates if CSRF cookie is secure.
|
||||
// Optional. Default value false.
|
||||
CookieSecure bool
|
||||
|
||||
// Indicates SameSite mode of the CSRF cookie.
|
||||
// Optional. Default value SameSiteDefaultMode.
|
||||
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
|
||||
}
|
||||
// Indicates if CSRF cookie is HTTP only.
|
||||
// Optional. Default value false.
|
||||
CookieHTTPOnly bool
|
||||
|
||||
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
||||
// either a token or an error.
|
||||
csrfTokenExtractor func(echo.Context) (string, error)
|
||||
)
|
||||
// Indicates SameSite mode of the CSRF cookie.
|
||||
// Optional. Default value SameSiteDefaultMode.
|
||||
CookieSameSite http.SameSite
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||
DefaultCSRFConfig = CSRFConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
TokenLength: 32,
|
||||
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||
ContextKey: "csrf",
|
||||
CookieName: "_csrf",
|
||||
CookieMaxAge: 86400,
|
||||
CookieSameSite: http.SameSiteDefaultMode,
|
||||
}
|
||||
)
|
||||
// csrfTokenExtractor defines a function that takes `echo.Context` and returns either a token or an error.
|
||||
type csrfTokenExtractor func(echo.Context) (string, error)
|
||||
|
||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||
var DefaultCSRFConfig = CSRFConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
TokenLength: 32,
|
||||
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||
ContextKey: "csrf",
|
||||
CookieName: "_csrf",
|
||||
CookieMaxAge: 86400,
|
||||
CookieSameSite: http.SameSiteDefaultMode,
|
||||
}
|
||||
|
||||
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
||||
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
||||
func CSRF() echo.MiddlewareFunc {
|
||||
c := DefaultCSRFConfig
|
||||
return CSRFWithConfig(c)
|
||||
return CSRFWithConfig(DefaultCSRFConfig)
|
||||
}
|
||||
|
||||
// CSRFWithConfig returns a CSRF middleware with config.
|
||||
// See `CSRF()`.
|
||||
// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
|
||||
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
|
||||
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultCSRFConfig.Skipper
|
||||
@ -98,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
if config.TokenLength == 0 {
|
||||
config.TokenLength = DefaultCSRFConfig.TokenLength
|
||||
}
|
||||
if config.Generator == nil {
|
||||
config.Generator = createRandomStringGenerator(config.TokenLength)
|
||||
}
|
||||
if config.TokenLookup == "" {
|
||||
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
||||
}
|
||||
@ -136,7 +140,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
|
||||
// Generate token
|
||||
if err != nil {
|
||||
token = random.String(config.TokenLength)
|
||||
token = config.Generator()
|
||||
} else {
|
||||
// Reuse token
|
||||
token = k.Value
|
||||
@ -181,7 +185,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||||
|
@ -9,11 +9,26 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCSRF(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
csrf := CSRF()
|
||||
h := csrf(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Generate CSRF token
|
||||
h(c)
|
||||
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
|
||||
|
||||
}
|
||||
|
||||
func TestMustCSRFWithConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@ -43,7 +58,7 @@ func TestCSRF(t *testing.T) {
|
||||
assert.Error(t, h(c))
|
||||
|
||||
// Valid CSRF token
|
||||
token := random.String(16)
|
||||
token := randomString(16)
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
||||
req.Header.Set(echo.HeaderXCSRFToken, token)
|
||||
if assert.NoError(t, h(c)) {
|
||||
@ -145,9 +160,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
csrf := CSRFWithConfig(CSRFConfig{
|
||||
csrf, err := CSRFConfig{
|
||||
CookieSameSite: http.SameSiteNoneMode,
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
h := csrf(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
|
@ -11,18 +11,16 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// DecompressConfig defines the config for Decompress middleware.
|
||||
DecompressConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// DecompressConfig defines the config for Decompress middleware.
|
||||
type DecompressConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
|
||||
GzipDecompressPool Decompressor
|
||||
}
|
||||
)
|
||||
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
|
||||
GzipDecompressPool Decompressor
|
||||
}
|
||||
|
||||
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
|
||||
// GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
|
||||
const GZIPEncoding string = "gzip"
|
||||
|
||||
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
|
||||
@ -30,14 +28,6 @@ type Decompressor interface {
|
||||
gzipDecompressPool() sync.Pool
|
||||
}
|
||||
|
||||
var (
|
||||
//DefaultDecompressConfig defines the config for decompress middleware
|
||||
DefaultDecompressConfig = DecompressConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
GzipDecompressPool: &DefaultGzipDecompressPool{},
|
||||
}
|
||||
)
|
||||
|
||||
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
|
||||
type DefaultGzipDecompressPool struct {
|
||||
}
|
||||
@ -65,19 +55,23 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
|
||||
}
|
||||
}
|
||||
|
||||
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
|
||||
// Decompress decompresses request body based if content encoding type is set to "gzip" with default config
|
||||
func Decompress() echo.MiddlewareFunc {
|
||||
return DecompressWithConfig(DefaultDecompressConfig)
|
||||
return DecompressWithConfig(DecompressConfig{})
|
||||
}
|
||||
|
||||
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
|
||||
// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
|
||||
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
|
||||
func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultGzipConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.GzipDecompressPool == nil {
|
||||
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
|
||||
config.GzipDecompressPool = &DefaultGzipDecompressPool{}
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -116,5 +110,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -17,6 +17,31 @@ import (
|
||||
|
||||
func TestDecompress(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
|
||||
// Decompress request body
|
||||
body := `{"name": "echo"}`
|
||||
gz, _ := gzipString(body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, string(b))
|
||||
}
|
||||
|
||||
func TestDecompress_skippedIfNoHeader(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
@ -26,39 +51,42 @@ func TestDecompress(t *testing.T) {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
h(c)
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.Equal("test", rec.Body.String())
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", rec.Body.String())
|
||||
|
||||
// Decompress
|
||||
body := `{"name": "echo"}`
|
||||
gz, _ := gzipString(body)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h(c)
|
||||
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal(body, string(b))
|
||||
}
|
||||
|
||||
func TestDecompressDefaultConfig(t *testing.T) {
|
||||
func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
|
||||
h, err := DecompressConfig{}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = h(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test", rec.Body.String())
|
||||
|
||||
}
|
||||
|
||||
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
return nil
|
||||
})
|
||||
h(c)
|
||||
|
||||
assert := assert.New(t)
|
||||
assert.Equal("test", rec.Body.String())
|
||||
|
||||
// Decompress
|
||||
body := `{"name": "echo"}`
|
||||
@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
h(c)
|
||||
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
assert.NoError(err)
|
||||
assert.Equal(body, string(b))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, string(b))
|
||||
}
|
||||
|
||||
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
|
||||
@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
e.NewContext(req, rec)
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
b, err := ioutil.ReadAll(req.Body)
|
||||
assert.NoError(t, err)
|
||||
@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) {
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
})
|
||||
if assert.NoError(t, h(c)) {
|
||||
|
||||
err := h(c)
|
||||
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
|
||||
assert.Equal(t, 0, len(rec.Body.Bytes()))
|
||||
@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
|
||||
}
|
||||
@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
|
||||
reqBody, err := ioutil.ReadAll(c.Request().Body)
|
||||
assert.NoError(t, err)
|
||||
@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
|
||||
reqBody, err := ioutil.ReadAll(c.Request().Body)
|
||||
assert.NoError(t, err)
|
||||
|
148
middleware/extractor.go
Normal file
148
middleware/extractor.go
Normal file
@ -0,0 +1,148 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrExtractionValueMissing denotes an error raised when value could not be extracted from request
|
||||
var ErrExtractionValueMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed value")
|
||||
|
||||
// ExtractorType is enum type for where extractor will take its data
|
||||
type ExtractorType string
|
||||
|
||||
const (
|
||||
// HeaderExtractor tells extractor to take values from request header
|
||||
HeaderExtractor ExtractorType = "header"
|
||||
// QueryExtractor tells extractor to take values from request query parameters
|
||||
QueryExtractor ExtractorType = "query"
|
||||
// ParamExtractor tells extractor to take values from request route parameters
|
||||
ParamExtractor ExtractorType = "param"
|
||||
// CookieExtractor tells extractor to take values from request cookie
|
||||
CookieExtractor ExtractorType = "cookie"
|
||||
// FormExtractor tells extractor to take values from request form fields
|
||||
FormExtractor ExtractorType = "form"
|
||||
)
|
||||
|
||||
func createExtractors(lookups string) ([]valuesExtractor, error) {
|
||||
sources := strings.Split(lookups, ",")
|
||||
var extractors []valuesExtractor
|
||||
for _, source := range sources {
|
||||
parts := strings.Split(source, ":")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
|
||||
}
|
||||
|
||||
switch ExtractorType(parts[0]) {
|
||||
case QueryExtractor:
|
||||
extractors = append(extractors, valuesFromQuery(parts[1]))
|
||||
case ParamExtractor:
|
||||
extractors = append(extractors, valuesFromParam(parts[1]))
|
||||
case CookieExtractor:
|
||||
extractors = append(extractors, valuesFromCookie(parts[1]))
|
||||
case FormExtractor:
|
||||
extractors = append(extractors, valuesFromForm(parts[1]))
|
||||
case HeaderExtractor:
|
||||
prefix := ""
|
||||
if len(parts) > 2 {
|
||||
prefix = parts[2]
|
||||
}
|
||||
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
|
||||
}
|
||||
}
|
||||
return extractors, nil
|
||||
}
|
||||
|
||||
// valuesFromHeader returns a functions that extracts values from the request header.
|
||||
func valuesFromHeader(header string, valuePrefix string) valuesExtractor {
|
||||
prefixLen := len(valuePrefix)
|
||||
return func(c echo.Context) ([]string, ExtractorType, error) {
|
||||
values := textproto.MIMEHeader(c.Request().Header).Values(header)
|
||||
if len(values) == 0 {
|
||||
return nil, HeaderExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
for _, value := range values {
|
||||
if prefixLen == 0 {
|
||||
result = append(result, value)
|
||||
continue
|
||||
}
|
||||
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
|
||||
result = append(result, value[prefixLen:])
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, HeaderExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
return result, HeaderExtractor, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromQuery returns a function that extracts values from the query string.
|
||||
func valuesFromQuery(param string) valuesExtractor {
|
||||
return func(c echo.Context) ([]string, ExtractorType, error) {
|
||||
result := c.QueryParams()[param]
|
||||
if len(result) == 0 {
|
||||
return nil, QueryExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
return result, QueryExtractor, nil
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromParam returns a function that extracts values from the url param string.
|
||||
func valuesFromParam(param string) valuesExtractor {
|
||||
return func(c echo.Context) ([]string, ExtractorType, error) {
|
||||
result := make([]string, 0)
|
||||
for _, p := range c.PathParams() {
|
||||
if param == p.Name {
|
||||
result = append(result, p.Value)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, ParamExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
return result, ParamExtractor, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromCookie returns a function that extracts values from the named cookie.
|
||||
func valuesFromCookie(name string) valuesExtractor {
|
||||
return func(c echo.Context) ([]string, ExtractorType, error) {
|
||||
cookies := c.Cookies()
|
||||
if len(cookies) == 0 {
|
||||
return nil, CookieExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
for _, cookie := range cookies {
|
||||
if name == cookie.Name {
|
||||
result = append(result, cookie.Value)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, CookieExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
return result, CookieExtractor, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromForm returns a function that extracts values from the form field.
|
||||
func valuesFromForm(name string) valuesExtractor {
|
||||
return func(c echo.Context) ([]string, ExtractorType, error) {
|
||||
if err := c.Request().ParseForm(); err != nil {
|
||||
return nil, FormExtractor, fmt.Errorf("valuesFromForm parse form failed: %w", err)
|
||||
}
|
||||
values := c.Request().Form[name]
|
||||
if len(values) == 0 {
|
||||
return nil, FormExtractor, ErrExtractionValueMissing
|
||||
}
|
||||
|
||||
result := append([]string{}, values...)
|
||||
return result, FormExtractor, nil
|
||||
}
|
||||
}
|
498
middleware/extractor_test.go
Normal file
498
middleware/extractor_test.go
Normal file
@ -0,0 +1,498 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateExtractors(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest func() *http.Request
|
||||
givenPathParams echo.PathParams
|
||||
whenLoopups string
|
||||
expectValues []string
|
||||
expectExtractorType ExtractorType
|
||||
expectCreateError string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, header",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
|
||||
return req
|
||||
},
|
||||
whenLoopups: "header:Authorization:Bearer ",
|
||||
expectValues: []string{"token"},
|
||||
expectExtractorType: HeaderExtractor,
|
||||
},
|
||||
{
|
||||
name: "ok, form",
|
||||
givenRequest: func() *http.Request {
|
||||
f := make(url.Values)
|
||||
f.Set("name", "Jon Snow")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
return req
|
||||
},
|
||||
whenLoopups: "form:name",
|
||||
expectValues: []string{"Jon Snow"},
|
||||
expectExtractorType: FormExtractor,
|
||||
},
|
||||
{
|
||||
name: "ok, cookie",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf=token")
|
||||
return req
|
||||
},
|
||||
whenLoopups: "cookie:_csrf",
|
||||
expectValues: []string{"token"},
|
||||
expectExtractorType: CookieExtractor,
|
||||
},
|
||||
{
|
||||
name: "ok, param",
|
||||
givenPathParams: echo.PathParams{
|
||||
{Name: "id", Value: "123"},
|
||||
},
|
||||
whenLoopups: "param:id",
|
||||
expectValues: []string{"123"},
|
||||
expectExtractorType: ParamExtractor,
|
||||
},
|
||||
{
|
||||
name: "ok, query",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
|
||||
return req
|
||||
},
|
||||
whenLoopups: "query:id",
|
||||
expectValues: []string{"999"},
|
||||
expectExtractorType: QueryExtractor,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequest != nil {
|
||||
req = tc.givenRequest()
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(echo.EditableContext)
|
||||
if tc.givenPathParams != nil {
|
||||
c.SetRawPathParams(&tc.givenPathParams)
|
||||
}
|
||||
|
||||
extractors, err := createExtractors(tc.whenLoopups)
|
||||
if tc.expectCreateError != "" {
|
||||
assert.EqualError(t, err, tc.expectCreateError)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, e := range extractors {
|
||||
values, eType, eErr := e(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, tc.expectExtractorType, eType)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, eErr, tc.expectError)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, eErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromHeader(t *testing.T) {
|
||||
exampleRequest := func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest func(req *http.Request)
|
||||
whenName string
|
||||
whenValuePrefix string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
|
||||
},
|
||||
{
|
||||
name: "ok, single value, case insensitive",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "Basic ",
|
||||
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||
},
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
|
||||
},
|
||||
{
|
||||
name: "ok, empty prefix",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "",
|
||||
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
|
||||
},
|
||||
{
|
||||
name: "nok, no matching due different prefix",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||
},
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "Bearer ",
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no matching due different prefix",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||
},
|
||||
whenName: echo.HeaderWWWAuthenticate,
|
||||
whenValuePrefix: "",
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no headers",
|
||||
givenRequest: nil,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequest != nil {
|
||||
tc.givenRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
|
||||
|
||||
values, eType, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, HeaderExtractor, eType)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromQuery(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenQueryPart string
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenQueryPart: "?id=123&name=test",
|
||||
whenName: "id",
|
||||
expectValues: []string{"123"},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenQueryPart: "?id=123&id=456&name=test",
|
||||
whenName: "id",
|
||||
expectValues: []string{"123", "456"},
|
||||
},
|
||||
{
|
||||
name: "nok, missing value",
|
||||
givenQueryPart: "?id=123&name=test",
|
||||
whenName: "nope",
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromQuery(tc.whenName)
|
||||
|
||||
values, eType, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, QueryExtractor, eType)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromParam(t *testing.T) {
|
||||
examplePathParams := echo.PathParams{
|
||||
{Name: "id", Value: "123"},
|
||||
{Name: "gid", Value: "456"},
|
||||
{Name: "gid", Value: "789"},
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenPathParams echo.PathParams
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenPathParams: examplePathParams,
|
||||
whenName: "id",
|
||||
expectValues: []string{"123"},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenPathParams: examplePathParams,
|
||||
whenName: "gid",
|
||||
expectValues: []string{"456", "789"},
|
||||
},
|
||||
{
|
||||
name: "nok, no values",
|
||||
givenPathParams: nil,
|
||||
whenName: "nope",
|
||||
expectValues: nil,
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no matching value",
|
||||
givenPathParams: examplePathParams,
|
||||
whenName: "nope",
|
||||
expectValues: nil,
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(echo.EditableContext)
|
||||
if tc.givenPathParams != nil {
|
||||
c.SetRawPathParams(&tc.givenPathParams)
|
||||
}
|
||||
|
||||
extractor := valuesFromParam(tc.whenName)
|
||||
|
||||
values, eType, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, ParamExtractor, eType)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromCookie(t *testing.T) {
|
||||
exampleRequest := func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf=token")
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest func(req *http.Request)
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: "_csrf",
|
||||
expectValues: []string{"token"},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Add(echo.HeaderCookie, "_csrf=token")
|
||||
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
|
||||
},
|
||||
whenName: "_csrf",
|
||||
expectValues: []string{"token", "token2"},
|
||||
},
|
||||
{
|
||||
name: "nok, no matching cookie",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: "xxx",
|
||||
expectValues: nil,
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no cookies at all",
|
||||
givenRequest: nil,
|
||||
whenName: "xxx",
|
||||
expectValues: nil,
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequest != nil {
|
||||
tc.givenRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromCookie(tc.whenName)
|
||||
|
||||
values, eType, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, CookieExtractor, eType)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromForm(t *testing.T) {
|
||||
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
|
||||
f := make(url.Values)
|
||||
f.Set("name", "Jon Snow")
|
||||
f.Set("emails[]", "jon@labstack.com")
|
||||
if mod != nil {
|
||||
mod(&f)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
|
||||
return req
|
||||
}
|
||||
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
|
||||
f := make(url.Values)
|
||||
f.Set("name", "Jon Snow")
|
||||
f.Set("emails[]", "jon@labstack.com")
|
||||
if mod != nil {
|
||||
mod(&f)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
|
||||
return req
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest *http.Request
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, POST form, single value",
|
||||
givenRequest: examplePostFormRequest(nil),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, POST form, multiple value",
|
||||
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||
v.Add("emails[]", "snow@labstack.com")
|
||||
}),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, GET form, single value",
|
||||
givenRequest: exampleGetFormRequest(nil),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, GET form, multiple value",
|
||||
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||
v.Add("emails[]", "snow@labstack.com")
|
||||
}),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "nok, POST form, value missing",
|
||||
givenRequest: examplePostFormRequest(nil),
|
||||
whenName: "nope",
|
||||
expectError: ErrExtractionValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, POST form, form parsing error",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Body = nil
|
||||
return req
|
||||
}(),
|
||||
whenName: "name",
|
||||
expectError: "valuesFromForm parse form failed: missing form body",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := tc.givenRequest
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromForm(tc.whenName)
|
||||
|
||||
values, eType, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, FormExtractor, eType)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,134 +1,84 @@
|
||||
// +build go1.15
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
// JWTConfig defines the config for JWT middleware.
|
||||
JWTConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// JWTConfig defines the config for JWT middleware.
|
||||
type JWTConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
BeforeFunc BeforeFunc
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
BeforeFunc BeforeFunc
|
||||
|
||||
// SuccessHandler defines a function which is executed for a valid token.
|
||||
SuccessHandler JWTSuccessHandler
|
||||
// SuccessHandler defines a function which is executed for a valid token.
|
||||
SuccessHandler JWTSuccessHandler
|
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid token.
|
||||
// It may be used to define a custom JWT error.
|
||||
ErrorHandler JWTErrorHandler
|
||||
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
|
||||
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
|
||||
// It may be used to define a custom JWT error.
|
||||
//
|
||||
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
|
||||
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
|
||||
// In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain.
|
||||
ErrorHandler JWTErrorHandlerWithContext
|
||||
|
||||
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
|
||||
ErrorHandlerWithContext JWTErrorHandlerWithContext
|
||||
// Context key to store user information from the token into context.
|
||||
// Optional. Default value "user".
|
||||
ContextKey string
|
||||
|
||||
// Signing key to validate token.
|
||||
// This is one of the three options to provide a token validation key.
|
||||
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||
// Required if neither user-defined KeyFunc nor SigningKeys is provided.
|
||||
SigningKey interface{}
|
||||
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||
// to extract token(s) from the request.
|
||||
// Optional. Default value "header:Authorization:Bearer ".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "cookie:<name>"
|
||||
// - "form:<name>"
|
||||
// Multiple sources example:
|
||||
// - "header:Authorization,cookie:myowncookie"
|
||||
TokenLookup string
|
||||
|
||||
// Map of signing keys to validate token with kid field usage.
|
||||
// This is one of the three options to provide a token validation key.
|
||||
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||
// Required if neither user-defined KeyFunc nor SigningKey is provided.
|
||||
SigningKeys map[string]interface{}
|
||||
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
|
||||
// parsing fails or parsed token is invalid.
|
||||
// NB: could be called multiple times per request when token lookup is able to extract multiple token values (i.e. multiple Authorization headers)
|
||||
// See `jwt_external_test.go` for example implementation using `github.com/golang-jwt/jwt` as JWT implementation library
|
||||
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
|
||||
}
|
||||
|
||||
// Signing method used to check the token's signing algorithm.
|
||||
// Optional. Default value HS256.
|
||||
SigningMethod string
|
||||
// JWTSuccessHandler defines a function which is executed for a valid token.
|
||||
type JWTSuccessHandler func(c echo.Context)
|
||||
|
||||
// Context key to store user information from the token into context.
|
||||
// Optional. Default value "user".
|
||||
ContextKey string
|
||||
// JWTErrorHandler defines a function which is executed for an invalid token.
|
||||
type JWTErrorHandler func(err error) error
|
||||
|
||||
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
|
||||
// Not used if custom ParseTokenFunc is set.
|
||||
// Optional. Default value jwt.MapClaims
|
||||
Claims jwt.Claims
|
||||
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
|
||||
type JWTErrorHandlerWithContext func(c echo.Context, err error) error
|
||||
|
||||
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||
// to extract token from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "cookie:<name>"
|
||||
// - "form:<name>"
|
||||
// Multiply sources example:
|
||||
// - "header: Authorization,cookie: myowncookie"
|
||||
type valuesExtractor func(c echo.Context) ([]string, ExtractorType, error)
|
||||
|
||||
TokenLookup string
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
AuthScheme string
|
||||
|
||||
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
|
||||
// The function shall take care of verifying the signing algorithm and selecting the proper key.
|
||||
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
|
||||
// Used by default ParseTokenFunc implementation.
|
||||
//
|
||||
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
|
||||
// This is one of the three options to provide a token validation key.
|
||||
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||
// Required if neither SigningKeys nor SigningKey is provided.
|
||||
// Not used if custom ParseTokenFunc is set.
|
||||
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
|
||||
KeyFunc jwt.Keyfunc
|
||||
|
||||
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
|
||||
// parsing fails or parsed token is invalid.
|
||||
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
|
||||
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
|
||||
}
|
||||
|
||||
// JWTSuccessHandler defines a function which is executed for a valid token.
|
||||
JWTSuccessHandler func(echo.Context)
|
||||
|
||||
// JWTErrorHandler defines a function which is executed for an invalid token.
|
||||
JWTErrorHandler func(error) error
|
||||
|
||||
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
|
||||
JWTErrorHandlerWithContext func(error, echo.Context) error
|
||||
|
||||
jwtExtractor func(echo.Context) (string, error)
|
||||
)
|
||||
|
||||
// Algorithms
|
||||
const (
|
||||
// AlgorithmHS256 is token signing algorithm
|
||||
AlgorithmHS256 = "HS256"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
|
||||
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
|
||||
)
|
||||
// ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request
|
||||
var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt")
|
||||
|
||||
var (
|
||||
// DefaultJWTConfig is the default JWT auth middleware config.
|
||||
DefaultJWTConfig = JWTConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
SigningMethod: AlgorithmHS256,
|
||||
ContextKey: "user",
|
||||
TokenLookup: "header:" + echo.HeaderAuthorization,
|
||||
AuthScheme: "Bearer",
|
||||
Claims: jwt.MapClaims{},
|
||||
KeyFunc: nil,
|
||||
}
|
||||
)
|
||||
// ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired
|
||||
var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
|
||||
|
||||
// DefaultJWTConfig is the default JWT auth middleware config.
|
||||
var DefaultJWTConfig = JWTConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
ContextKey: "user",
|
||||
TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
|
||||
}
|
||||
|
||||
// JWT returns a JSON Web Token (JWT) auth middleware.
|
||||
//
|
||||
@ -137,64 +87,43 @@ var (
|
||||
// For missing token, it returns "400 - Bad Request" error.
|
||||
//
|
||||
// See: https://jwt.io/introduction
|
||||
// See `JWTConfig.TokenLookup`
|
||||
func JWT(key interface{}) echo.MiddlewareFunc {
|
||||
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
|
||||
c := DefaultJWTConfig
|
||||
c.SigningKey = key
|
||||
c.ParseTokenFunc = parseTokenFunc
|
||||
return JWTWithConfig(c)
|
||||
}
|
||||
|
||||
// JWTWithConfig returns a JWT auth middleware with config.
|
||||
// See: `JWT()`.
|
||||
// JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid.
|
||||
//
|
||||
// For valid token, it sets the user in context and calls next handler.
|
||||
// For invalid token, it returns "401 - Unauthorized" error.
|
||||
// For missing token, it returns "400 - Bad Request" error.
|
||||
//
|
||||
// See: https://jwt.io/introduction
|
||||
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration
|
||||
func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultJWTConfig.Skipper
|
||||
}
|
||||
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
|
||||
panic("echo: jwt middleware requires signing key")
|
||||
}
|
||||
if config.SigningMethod == "" {
|
||||
config.SigningMethod = DefaultJWTConfig.SigningMethod
|
||||
if config.ParseTokenFunc == nil {
|
||||
return nil, errors.New("echo jwt middleware requires parse token function")
|
||||
}
|
||||
if config.ContextKey == "" {
|
||||
config.ContextKey = DefaultJWTConfig.ContextKey
|
||||
}
|
||||
if config.Claims == nil {
|
||||
config.Claims = DefaultJWTConfig.Claims
|
||||
}
|
||||
if config.TokenLookup == "" {
|
||||
config.TokenLookup = DefaultJWTConfig.TokenLookup
|
||||
}
|
||||
if config.AuthScheme == "" {
|
||||
config.AuthScheme = DefaultJWTConfig.AuthScheme
|
||||
extractors, err := createExtractors(config.TokenLookup)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("echo jwt middleware could not create token extractor: %w", err)
|
||||
}
|
||||
if config.KeyFunc == nil {
|
||||
config.KeyFunc = config.defaultKeyFunc
|
||||
}
|
||||
if config.ParseTokenFunc == nil {
|
||||
config.ParseTokenFunc = config.defaultParseToken
|
||||
}
|
||||
|
||||
// Initialize
|
||||
// Split sources
|
||||
sources := strings.Split(config.TokenLookup, ",")
|
||||
var extractors []jwtExtractor
|
||||
for _, source := range sources {
|
||||
parts := strings.Split(source, ":")
|
||||
|
||||
switch parts[0] {
|
||||
case "query":
|
||||
extractors = append(extractors, jwtFromQuery(parts[1]))
|
||||
case "param":
|
||||
extractors = append(extractors, jwtFromParam(parts[1]))
|
||||
case "cookie":
|
||||
extractors = append(extractors, jwtFromCookie(parts[1]))
|
||||
case "form":
|
||||
extractors = append(extractors, jwtFromForm(parts[1]))
|
||||
case "header":
|
||||
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
|
||||
}
|
||||
if len(extractors) == 0 {
|
||||
return nil, errors.New("echo jwt middleware could not create extractors from TokenLookup string")
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -206,142 +135,55 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
if config.BeforeFunc != nil {
|
||||
config.BeforeFunc(c)
|
||||
}
|
||||
var auth string
|
||||
var err error
|
||||
var lastExtractorErr error
|
||||
var lastTokenErr error
|
||||
for _, extractor := range extractors {
|
||||
// Extract token from extractor, if it's not fail break the loop and
|
||||
// set auth
|
||||
auth, err = extractor(c)
|
||||
if err == nil {
|
||||
break
|
||||
auths, _, extrErr := extractor(c)
|
||||
if extrErr != nil {
|
||||
lastExtractorErr = extrErr
|
||||
continue
|
||||
}
|
||||
for _, auth := range auths {
|
||||
token, err := config.ParseTokenFunc(c, auth)
|
||||
if err != nil {
|
||||
lastTokenErr = err
|
||||
continue
|
||||
}
|
||||
// Store user information from token into context.
|
||||
c.Set(config.ContextKey, token)
|
||||
if config.SuccessHandler != nil {
|
||||
config.SuccessHandler(c)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
// If none of extractor has a token, handle error
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err)
|
||||
}
|
||||
|
||||
if config.ErrorHandlerWithContext != nil {
|
||||
return config.ErrorHandlerWithContext(err, c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := config.ParseTokenFunc(auth, c)
|
||||
// prioritize token errors over extracting errors
|
||||
err := lastTokenErr
|
||||
if err == nil {
|
||||
// Store user information from token into context.
|
||||
c.Set(config.ContextKey, token)
|
||||
if config.SuccessHandler != nil {
|
||||
config.SuccessHandler(c)
|
||||
err = lastExtractorErr
|
||||
}
|
||||
if config.ErrorHandler != nil {
|
||||
if err == ErrExtractionValueMissing {
|
||||
err = ErrJWTMissing
|
||||
}
|
||||
// Allow error handler to swallow the error and continue handler chain execution
|
||||
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
|
||||
// In that case you can use ErrorHandler to set default public token to request and continue with handler chain
|
||||
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
|
||||
return handledErr
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err)
|
||||
}
|
||||
if config.ErrorHandlerWithContext != nil {
|
||||
return config.ErrorHandlerWithContext(err, c)
|
||||
if err == ErrExtractionValueMissing {
|
||||
return ErrJWTMissing
|
||||
}
|
||||
// everything else goes under http.StatusUnauthorized to avoid exposing JWT internals with generic error
|
||||
return &echo.HTTPError{
|
||||
Code: ErrJWTInvalid.Code,
|
||||
Message: ErrJWTInvalid.Message,
|
||||
Internal: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
|
||||
token := new(jwt.Token)
|
||||
var err error
|
||||
// Issue #647, #656
|
||||
if _, ok := config.Claims.(jwt.MapClaims); ok {
|
||||
token, err = jwt.Parse(auth, config.KeyFunc)
|
||||
} else {
|
||||
t := reflect.ValueOf(config.Claims).Type().Elem()
|
||||
claims := reflect.New(t).Interface().(jwt.Claims)
|
||||
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// defaultKeyFunc returns a signing key of the given token.
|
||||
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
|
||||
// Check the signing method
|
||||
if t.Method.Alg() != config.SigningMethod {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
if len(config.SigningKeys) > 0 {
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if key, ok := config.SigningKeys[kid]; ok {
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
|
||||
}
|
||||
|
||||
return config.SigningKey, nil
|
||||
}
|
||||
|
||||
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
|
||||
func jwtFromHeader(header string, authScheme string) jwtExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
auth := c.Request().Header.Get(header)
|
||||
l := len(authScheme)
|
||||
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
|
||||
func jwtFromQuery(param string) jwtExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.QueryParam(param)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
|
||||
func jwtFromParam(param string) jwtExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.Param(param)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
|
||||
func jwtFromCookie(name string) jwtExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
cookie, err := c.Cookie(name)
|
||||
if err != nil {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return cookie.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
|
||||
func jwtFromForm(name string) jwtExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
field := c.FormValue(name)
|
||||
if field == "" {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return field, nil
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
76
middleware/jwt_external_test.go
Normal file
76
middleware/jwt_external_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
)
|
||||
|
||||
// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc
|
||||
//
|
||||
// signingKey is signing key to validate token.
|
||||
// This is one of the options to provide a token validation key.
|
||||
// The order of precedence is a user-defined SigningKeys and SigningKey.
|
||||
// Required if signingKeys is not provided.
|
||||
//
|
||||
// signingKeys is Map of signing keys to validate token with kid field usage.
|
||||
// This is one of the options to provide a token validation key.
|
||||
// The order of precedence is a user-defined SigningKeys and SigningKey.
|
||||
// Required if signingKey is not provided
|
||||
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
|
||||
// keyFunc defines a user-defined function that supplies the public key for a token validation.
|
||||
// The function shall take care of verifying the signing algorithm and selecting the proper key.
|
||||
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if t.Method.Alg() != middleware.AlgorithmHS256 {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
if len(signingKeys) == 0 {
|
||||
return signingKey, nil
|
||||
}
|
||||
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if key, ok := signingKeys[kid]; ok {
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
|
||||
}
|
||||
|
||||
return func(c echo.Context, auth string) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleJWTConfig_withJWTGoAsTokenParser() {
|
||||
mw := middleware.JWTWithConfig(middleware.JWTConfig{
|
||||
ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil),
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
e.Use(mw)
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
return c.JSON(http.StatusTeapot, user.Claims)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
fmt.Printf("status: %v, body: %v", res.Code, res.Body.String())
|
||||
// Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"}
|
||||
}
|
@ -1,5 +1,3 @@
|
||||
// +build go1.15
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@ -11,11 +9,32 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
|
||||
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if t.Method.Alg() != signingMethod {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
return signingKey, nil
|
||||
}
|
||||
|
||||
return func(c echo.Context, auth string) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtCustomInfo defines some custom types we're going to use within our tokens.
|
||||
type jwtCustomInfo struct {
|
||||
Name string `json:"name"`
|
||||
@ -28,43 +47,7 @@ type jwtCustomClaims struct {
|
||||
jwtCustomInfo
|
||||
}
|
||||
|
||||
func TestJWTRace(t *testing.T) {
|
||||
e := echo.New()
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
|
||||
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
|
||||
validKey := []byte("secret")
|
||||
|
||||
h := JWTWithConfig(JWTConfig{
|
||||
Claims: &jwtCustomClaims{},
|
||||
SigningKey: validKey,
|
||||
})(handler)
|
||||
|
||||
makeReq := func(token string) echo.Context {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
|
||||
c := e.NewContext(req, res)
|
||||
assert.NoError(t, h(c))
|
||||
return c
|
||||
}
|
||||
|
||||
c := makeReq(initialToken)
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
claims := user.Claims.(*jwtCustomClaims)
|
||||
assert.Equal(t, claims.Name, "John Doe")
|
||||
|
||||
makeReq(raceToken)
|
||||
user = c.Get("user").(*jwt.Token)
|
||||
claims = user.Claims.(*jwtCustomClaims)
|
||||
// Initial context should still be "John Doe", not "Race Condition"
|
||||
assert.Equal(t, claims.Name, "John Doe")
|
||||
assert.Equal(t, claims.Admin, true)
|
||||
}
|
||||
|
||||
func TestJWT(t *testing.T) {
|
||||
func TestJWT_combinations(t *testing.T) {
|
||||
e := echo.New()
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
@ -72,344 +55,236 @@ func TestJWT(t *testing.T) {
|
||||
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
|
||||
validKey := []byte("secret")
|
||||
invalidKey := []byte("invalid-key")
|
||||
validAuth := DefaultJWTConfig.AuthScheme + " " + token
|
||||
validAuth := "Bearer " + token
|
||||
|
||||
for _, tc := range []struct {
|
||||
expPanic bool
|
||||
expErrCode int // 0 for Success
|
||||
var testCases = []struct {
|
||||
name string
|
||||
config JWTConfig
|
||||
reqURL string // "/" if empty
|
||||
hdrAuth string
|
||||
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
|
||||
formValues map[string]string
|
||||
info string
|
||||
expPanic bool
|
||||
expErrCode int // 0 for Success
|
||||
}{
|
||||
{
|
||||
expPanic: true,
|
||||
info: "No signing key provided",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusBadRequest,
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
SigningMethod: "RS256",
|
||||
},
|
||||
info: "Unexpected signing method",
|
||||
name: "No signing key provided",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{SigningKey: invalidKey},
|
||||
info: "Invalid key",
|
||||
config: JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey),
|
||||
},
|
||||
name: "Unexpected signing method",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey),
|
||||
},
|
||||
name: "Invalid key",
|
||||
},
|
||||
{
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
info: "Valid JWT",
|
||||
config: JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
},
|
||||
name: "Valid JWT",
|
||||
},
|
||||
{
|
||||
hdrAuth: "Token" + " " + token,
|
||||
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
|
||||
info: "Valid JWT with custom AuthScheme",
|
||||
config: JWTConfig{
|
||||
TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
},
|
||||
name: "Valid JWT with custom AuthScheme",
|
||||
},
|
||||
{
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
Claims: &jwtCustomClaims{},
|
||||
SigningKey: []byte("secret"),
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
},
|
||||
info: "Valid JWT with custom claims",
|
||||
name: "Valid JWT with custom claims",
|
||||
},
|
||||
{
|
||||
hdrAuth: "invalid-auth",
|
||||
expErrCode: http.StatusBadRequest,
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
info: "Invalid Authorization header",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty header auth field",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
config: JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
},
|
||||
name: "Invalid Authorization header",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
name: "Empty header auth field",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b&jwt=" + token,
|
||||
info: "Valid query method",
|
||||
name: "Valid query method",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b&jwtxyz=" + token,
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Invalid query param name",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
name: "Invalid query param name",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b&jwt=invalid-token",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
info: "Invalid query param value",
|
||||
name: "Invalid query param value",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b",
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty query",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
name: "Empty query",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "param:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "param:jwt",
|
||||
},
|
||||
reqURL: "/" + token,
|
||||
info: "Valid param method",
|
||||
name: "Valid param method",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "cookie:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "cookie:jwt",
|
||||
},
|
||||
hdrCookie: "jwt=" + token,
|
||||
info: "Valid cookie method",
|
||||
name: "Valid cookie method",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt,cookie:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "query:jwt,cookie:jwt",
|
||||
},
|
||||
hdrCookie: "jwt=" + token,
|
||||
info: "Multiple jwt lookuop",
|
||||
name: "Multiple jwt lookuop",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "cookie:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "cookie:jwt",
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrCookie: "jwt=invalid",
|
||||
info: "Invalid token with cookie method",
|
||||
name: "Invalid token with cookie method",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "cookie:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "cookie:jwt",
|
||||
},
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty cookie",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
name: "Empty cookie",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "form:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
formValues: map[string]string{"jwt": token},
|
||||
info: "Valid form method",
|
||||
name: "Valid form method",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "form:jwt",
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
formValues: map[string]string{"jwt": "invalid"},
|
||||
info: "Invalid token with form method",
|
||||
name: "Invalid token with form method",
|
||||
},
|
||||
{
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty form field",
|
||||
},
|
||||
{
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||
return validKey, nil
|
||||
},
|
||||
},
|
||||
info: "Valid JWT with a valid key using a user-defined KeyFunc",
|
||||
},
|
||||
{
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||
return invalidKey, nil
|
||||
},
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
|
||||
name: "Empty form field",
|
||||
},
|
||||
{
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||
return nil, errors.New("faulty KeyFunc")
|
||||
},
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
info: "Token verification does not pass using a user-defined KeyFunc",
|
||||
},
|
||||
{
|
||||
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
info: "Valid JWT with lower case AuthScheme",
|
||||
},
|
||||
} {
|
||||
if tc.reqURL == "" {
|
||||
tc.reqURL = "/"
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
if len(tc.formValues) > 0 {
|
||||
form := url.Values{}
|
||||
for k, v := range tc.formValues {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
|
||||
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
|
||||
req.ParseForm()
|
||||
} else {
|
||||
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
|
||||
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
if tc.reqURL == "/"+token {
|
||||
c.SetParamNames("jwt")
|
||||
c.SetParamValues(token)
|
||||
}
|
||||
|
||||
if tc.expPanic {
|
||||
assert.Panics(t, func() {
|
||||
JWTWithConfig(tc.config)
|
||||
}, tc.info)
|
||||
continue
|
||||
}
|
||||
|
||||
if tc.expErrCode != 0 {
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
|
||||
continue
|
||||
}
|
||||
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
if assert.NoError(t, h(c), tc.info) {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
switch claims := user.Claims.(type) {
|
||||
case jwt.MapClaims:
|
||||
assert.Equal(t, claims["name"], "John Doe", tc.info)
|
||||
case *jwtCustomClaims:
|
||||
assert.Equal(t, claims.Name, "John Doe", tc.info)
|
||||
assert.Equal(t, claims.Admin, true, tc.info)
|
||||
default:
|
||||
panic("unexpected type of claims")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTwithKID(t *testing.T) {
|
||||
test := assert.New(t)
|
||||
|
||||
e := echo.New()
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk"
|
||||
secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM"
|
||||
wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90"
|
||||
staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o"
|
||||
validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")}
|
||||
invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")}
|
||||
staticSecret := []byte("static_secret")
|
||||
invalidStaticSecret := []byte("invalid_secret")
|
||||
|
||||
for _, tc := range []struct {
|
||||
expErrCode int // 0 for Success
|
||||
config JWTConfig
|
||||
hdrAuth string
|
||||
info string
|
||||
}{
|
||||
{
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
|
||||
config: JWTConfig{SigningKeys: validKeys},
|
||||
info: "First token valid",
|
||||
},
|
||||
{
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
|
||||
config: JWTConfig{SigningKeys: validKeys},
|
||||
info: "Second token valid",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken,
|
||||
config: JWTConfig{SigningKeys: validKeys},
|
||||
info: "Wrong key id token",
|
||||
},
|
||||
{
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
|
||||
config: JWTConfig{SigningKey: staticSecret},
|
||||
info: "Valid static secret token",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
|
||||
config: JWTConfig{SigningKey: invalidStaticSecret},
|
||||
info: "Invalid static secret",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
|
||||
config: JWTConfig{SigningKeys: invalidKeys},
|
||||
info: "Invalid keys first token",
|
||||
},
|
||||
{
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
|
||||
config: JWTConfig{SigningKeys: invalidKeys},
|
||||
info: "Invalid keys second token",
|
||||
},
|
||||
} {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
if tc.expErrCode != 0 {
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
test.Equal(tc.expErrCode, he.Code, tc.info)
|
||||
continue
|
||||
}
|
||||
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
if test.NoError(h(c), tc.info) {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
switch claims := user.Claims.(type) {
|
||||
case jwt.MapClaims:
|
||||
test.Equal(claims["name"], "John Doe", tc.info)
|
||||
case *jwtCustomClaims:
|
||||
test.Equal(claims.Name, "John Doe", tc.info)
|
||||
test.Equal(claims.Admin, true, tc.info)
|
||||
default:
|
||||
panic("unexpected type of claims")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.reqURL == "" {
|
||||
tc.reqURL = "/"
|
||||
}
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
if len(tc.formValues) > 0 {
|
||||
form := url.Values{}
|
||||
for k, v := range tc.formValues {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
|
||||
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
|
||||
req.ParseForm()
|
||||
} else {
|
||||
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
|
||||
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
if tc.reqURL == "/"+token {
|
||||
cc := c.(echo.EditableContext)
|
||||
cc.SetPathParams(echo.PathParams{
|
||||
{Name: "jwt", Value: token},
|
||||
})
|
||||
}
|
||||
|
||||
if tc.expPanic {
|
||||
assert.Panics(t, func() {
|
||||
JWTWithConfig(tc.config)
|
||||
}, tc.name)
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expErrCode != 0 {
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, tc.expErrCode, he.Code)
|
||||
return
|
||||
}
|
||||
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
if assert.NoError(t, h(c), tc.name) {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
switch claims := user.Claims.(type) {
|
||||
case jwt.MapClaims:
|
||||
assert.Equal(t, claims["name"], "John Doe")
|
||||
case *jwtCustomClaims:
|
||||
assert.Equal(t, claims.Name, "John Doe")
|
||||
assert.Equal(t, claims.Admin, true)
|
||||
default:
|
||||
panic("unexpected type of claims")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -420,7 +295,7 @@ func TestJWTConfig_skipper(t *testing.T) {
|
||||
Skipper: func(context echo.Context) bool {
|
||||
return true // skip everything
|
||||
},
|
||||
SigningKey: []byte("secret"),
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
}))
|
||||
|
||||
isCalled := false
|
||||
@ -448,11 +323,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) {
|
||||
BeforeFunc: func(context echo.Context) {
|
||||
isCalled = true
|
||||
},
|
||||
SigningKey: []byte("secret"),
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
@ -469,18 +344,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) {
|
||||
{
|
||||
name: "ok, ErrorHandler is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandler: func(err error) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
|
||||
},
|
||||
},
|
||||
expectStatusCode: http.StatusTeapot,
|
||||
},
|
||||
{
|
||||
name: "ok, ErrorHandlerWithContext is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandlerWithContext: func(err error, context echo.Context) error {
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
ErrorHandler: func(c echo.Context, err error) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
|
||||
},
|
||||
},
|
||||
@ -515,23 +380,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
|
||||
{
|
||||
name: "ok, ErrorHandler is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandler: func(err error) error {
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
ErrorHandler: func(c echo.Context, err error) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
|
||||
},
|
||||
},
|
||||
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
|
||||
},
|
||||
{
|
||||
name: "ok, ErrorHandlerWithContext is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandlerWithContext: func(err error, context echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
|
||||
},
|
||||
},
|
||||
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -544,14 +399,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
|
||||
|
||||
config := tc.given
|
||||
parseTokenCalled := false
|
||||
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) {
|
||||
config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
|
||||
parseTokenCalled = true
|
||||
return nil, errors.New("parsing failed")
|
||||
}
|
||||
e.Use(JWTWithConfig(config))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(res, req)
|
||||
@ -574,7 +429,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
|
||||
signingKey := []byte("secret")
|
||||
|
||||
config := JWTConfig{
|
||||
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
|
||||
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if t.Method.Alg() != "HS256" {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
@ -597,9 +452,161 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
|
||||
e.Use(JWTWithConfig(config))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
}
|
||||
|
||||
func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
success := c.Get("success").(string)
|
||||
user := c.Get("user").(string)
|
||||
return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user))
|
||||
})
|
||||
|
||||
mw, err := JWTConfig{
|
||||
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
|
||||
return auth, nil
|
||||
},
|
||||
SuccessHandler: func(c echo.Context) {
|
||||
c.Set("success", "yes")
|
||||
},
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
e.Use(mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64")
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, "yes:valid_token_base64", res.Body.String())
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
}
|
||||
|
||||
func TestJWTWithConfig_CallNextOnNilErrorHandlerResult(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenCallNext bool
|
||||
givenErrorHandler JWTErrorHandlerWithContext
|
||||
givenTokenLookup string
|
||||
whenAuthHeaders []string
|
||||
whenCookies []string
|
||||
whenParseReturn string
|
||||
whenParseError error
|
||||
expectHandlerCalled bool
|
||||
expect string
|
||||
expectCode int
|
||||
}{
|
||||
{
|
||||
name: "ok, with valid JWT from auth header",
|
||||
givenCallNext: true,
|
||||
givenErrorHandler: func(c echo.Context, err error) error {
|
||||
return nil
|
||||
},
|
||||
whenAuthHeaders: []string{"Bearer valid_token_base64"},
|
||||
whenParseReturn: "valid_token",
|
||||
expectCode: http.StatusTeapot,
|
||||
expect: "valid_token",
|
||||
},
|
||||
{
|
||||
name: "ok, missing header, callNext and set public_token from error handler",
|
||||
givenCallNext: true,
|
||||
givenErrorHandler: func(c echo.Context, err error) error {
|
||||
if err != ErrJWTMissing {
|
||||
panic("must get ErrJWTMissing")
|
||||
}
|
||||
c.Set("user", "public_token")
|
||||
return nil
|
||||
},
|
||||
whenAuthHeaders: []string{}, // no JWT header
|
||||
expectCode: http.StatusTeapot,
|
||||
expect: "public_token",
|
||||
},
|
||||
{
|
||||
name: "ok, invalid token, callNext and set public_token from error handler",
|
||||
givenCallNext: true,
|
||||
givenErrorHandler: func(c echo.Context, err error) error {
|
||||
// this is probably not realistic usecase. on parse error you probably want to return error
|
||||
if err.Error() != "parser_error" {
|
||||
panic("must get parser_error")
|
||||
}
|
||||
c.Set("user", "public_token")
|
||||
return nil
|
||||
},
|
||||
whenAuthHeaders: []string{"Bearer invalid_header"},
|
||||
whenParseError: errors.New("parser_error"),
|
||||
expectCode: http.StatusTeapot,
|
||||
expect: "public_token",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid token, return error from error handler",
|
||||
givenCallNext: true,
|
||||
givenErrorHandler: func(c echo.Context, err error) error {
|
||||
if err.Error() != "parser_error" {
|
||||
panic("must get parser_error")
|
||||
}
|
||||
return err
|
||||
},
|
||||
whenAuthHeaders: []string{"Bearer invalid_header"},
|
||||
whenParseError: errors.New("parser_error"),
|
||||
expectCode: http.StatusInternalServerError,
|
||||
expect: "{\"message\":\"Internal Server Error\"}\n",
|
||||
},
|
||||
{
|
||||
name: "nok, callNext but return error from error handler",
|
||||
givenCallNext: true,
|
||||
givenErrorHandler: func(c echo.Context, err error) error {
|
||||
return err
|
||||
},
|
||||
whenAuthHeaders: []string{}, // no JWT header
|
||||
expectCode: http.StatusUnauthorized,
|
||||
expect: "{\"message\":\"missing or malformed jwt\"}\n",
|
||||
},
|
||||
{
|
||||
name: "nok, callNext=false",
|
||||
givenCallNext: false,
|
||||
givenErrorHandler: func(c echo.Context, err error) error {
|
||||
return err
|
||||
},
|
||||
whenAuthHeaders: []string{}, // no JWT header
|
||||
expectCode: http.StatusUnauthorized,
|
||||
expect: "{\"message\":\"missing or malformed jwt\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
token := c.Get("user").(string)
|
||||
return c.String(http.StatusTeapot, token)
|
||||
})
|
||||
|
||||
mw, err := JWTConfig{
|
||||
TokenLookup: tc.givenTokenLookup,
|
||||
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
|
||||
return tc.whenParseReturn, tc.whenParseError
|
||||
},
|
||||
ErrorHandler: tc.givenErrorHandler,
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
e.Use(mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
for _, a := range tc.whenAuthHeaders {
|
||||
req.Header.Add(echo.HeaderAuthorization, a)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tc.expect, res.Body.String())
|
||||
assert.Equal(t, tc.expectCode, res.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3,58 +3,59 @@ package middleware
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
// KeyAuthConfig defines the config for KeyAuth middleware.
|
||||
KeyAuthConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// KeyAuthConfig defines the config for KeyAuth middleware.
|
||||
type KeyAuthConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<name>" that is used
|
||||
// to extract key from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "query:<name>"
|
||||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
KeyLookup string `yaml:"key_lookup"`
|
||||
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||
// to extract key(s) from the request.
|
||||
// Optional. Default value "header:Authorization:Bearer ".
|
||||
// Possible values:
|
||||
// - "header:<name>:<value prefix>"
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "cookie:<name>"
|
||||
// - "form:<name>"
|
||||
// Multiple sources example:
|
||||
// - "header:Authorization:Bearer ,cookie:myowncookie"
|
||||
KeyLookup string
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
AuthScheme string
|
||||
// Validator is a function to validate key.
|
||||
// Required.
|
||||
Validator KeyAuthValidator
|
||||
|
||||
// Validator is a function to validate key.
|
||||
// Required.
|
||||
Validator KeyAuthValidator
|
||||
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
|
||||
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
|
||||
// It may be used to define a custom error.
|
||||
//
|
||||
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
|
||||
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
|
||||
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
|
||||
ErrorHandler KeyAuthErrorHandler
|
||||
}
|
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid key.
|
||||
// It may be used to define a custom error.
|
||||
ErrorHandler KeyAuthErrorHandler
|
||||
}
|
||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||
type KeyAuthValidator func(c echo.Context, key string, keyType ExtractorType) (bool, error)
|
||||
|
||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||
KeyAuthValidator func(string, echo.Context) (bool, error)
|
||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||
type KeyAuthErrorHandler func(c echo.Context, err error) error
|
||||
|
||||
keyExtractor func(echo.Context) (string, error)
|
||||
// ErrKeyMissing denotes an error raised when key value could not be extracted from request
|
||||
var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
|
||||
|
||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||
KeyAuthErrorHandler func(error, echo.Context) error
|
||||
)
|
||||
// ErrInvalidKey denotes an error raised when key value is invalid by validator
|
||||
var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
|
||||
|
||||
var (
|
||||
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
|
||||
DefaultKeyAuthConfig = KeyAuthConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
KeyLookup: "header:" + echo.HeaderAuthorization,
|
||||
AuthScheme: "Bearer",
|
||||
}
|
||||
)
|
||||
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
|
||||
var DefaultKeyAuthConfig = KeyAuthConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
|
||||
}
|
||||
|
||||
// KeyAuth returns an KeyAuth middleware.
|
||||
//
|
||||
@ -67,34 +68,32 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
|
||||
return KeyAuthWithConfig(c)
|
||||
}
|
||||
|
||||
// KeyAuthWithConfig returns an KeyAuth middleware with config.
|
||||
// See `KeyAuth()`.
|
||||
// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
|
||||
//
|
||||
// For first valid key it calls the next handler.
|
||||
// For invalid key, it sends "401 - Unauthorized" response.
|
||||
// For missing key, it sends "400 - Bad Request" response.
|
||||
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
|
||||
func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultKeyAuthConfig.Skipper
|
||||
}
|
||||
// Defaults
|
||||
if config.AuthScheme == "" {
|
||||
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
|
||||
}
|
||||
if config.KeyLookup == "" {
|
||||
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
|
||||
}
|
||||
if config.Validator == nil {
|
||||
panic("echo: key-auth middleware requires a validator function")
|
||||
return nil, errors.New("echo key-auth middleware requires a validator function")
|
||||
}
|
||||
|
||||
// Initialize
|
||||
parts := strings.Split(config.KeyLookup, ":")
|
||||
extractor := keyFromHeader(parts[1], config.AuthScheme)
|
||||
switch parts[0] {
|
||||
case "query":
|
||||
extractor = keyFromQuery(parts[1])
|
||||
case "form":
|
||||
extractor = keyFromForm(parts[1])
|
||||
case "cookie":
|
||||
extractor = keyFromCookie(parts[1])
|
||||
extractors, err := createExtractors(config.KeyLookup)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err)
|
||||
}
|
||||
if len(extractors) == 0 {
|
||||
return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -103,79 +102,50 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Extract and verify key
|
||||
key, err := extractor(c)
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err, c)
|
||||
var lastExtractorErr error
|
||||
var lastValidatorErr error
|
||||
for _, extractor := range extractors {
|
||||
keys, keyType, extrErr := extractor(c)
|
||||
if extrErr != nil {
|
||||
lastExtractorErr = extrErr
|
||||
continue
|
||||
}
|
||||
for _, key := range keys {
|
||||
valid, err := config.Validator(c, key, keyType)
|
||||
if err != nil {
|
||||
lastValidatorErr = err
|
||||
continue
|
||||
}
|
||||
if !valid {
|
||||
lastValidatorErr = ErrInvalidKey
|
||||
continue
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
valid, err := config.Validator(key, c)
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err, c)
|
||||
|
||||
// prioritize validator errors over extracting errors
|
||||
err := lastValidatorErr
|
||||
if err == nil {
|
||||
err = lastExtractorErr
|
||||
}
|
||||
if config.ErrorHandler != nil {
|
||||
// Allow error handler to swallow the error and continue handler chain execution
|
||||
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
|
||||
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain
|
||||
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
|
||||
return handledErr
|
||||
}
|
||||
return &echo.HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "invalid key",
|
||||
Internal: err,
|
||||
}
|
||||
} else if valid {
|
||||
return next(c)
|
||||
}
|
||||
return echo.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
|
||||
func keyFromHeader(header string, authScheme string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
auth := c.Request().Header.Get(header)
|
||||
if auth == "" {
|
||||
return "", errors.New("missing key in request header")
|
||||
}
|
||||
if header == echo.HeaderAuthorization {
|
||||
l := len(authScheme)
|
||||
if len(auth) > l+1 && auth[:l] == authScheme {
|
||||
return auth[l+1:], nil
|
||||
if err == ErrExtractionValueMissing {
|
||||
return ErrKeyMissing // do not wrap extractor errors
|
||||
}
|
||||
return &echo.HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "Unauthorized",
|
||||
Internal: err,
|
||||
}
|
||||
return "", errors.New("invalid key in the request header")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
|
||||
func keyFromQuery(param string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key := c.QueryParam(param)
|
||||
if key == "" {
|
||||
return "", errors.New("missing key in the query string")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromForm returns a `keyExtractor` that extracts key from the form.
|
||||
func keyFromForm(param string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key := c.FormValue(param)
|
||||
if key == "" {
|
||||
return "", errors.New("missing key in the form")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
|
||||
func keyFromCookie(cookieName string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key, err := c.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("missing key in cookies: %w", err)
|
||||
}
|
||||
return key.Value, nil
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testKeyValidator(key string, c echo.Context) (bool, error) {
|
||||
func testKeyValidator(c echo.Context, key string, keyType ExtractorType) (bool, error) {
|
||||
switch key {
|
||||
case "valid-key":
|
||||
return true, nil
|
||||
@ -28,7 +28,7 @@ func TestKeyAuth(t *testing.T) {
|
||||
handlerCalled = true
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
middlewareChain := KeyAuth(testKeyValidator)(handler)
|
||||
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{Validator: testKeyValidator})(handler)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=Unauthorized",
|
||||
expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key",
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, invalid scheme in header",
|
||||
@ -84,13 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=invalid key in the request header",
|
||||
expectError: "code=401, message=missing key",
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, missing header",
|
||||
givenRequest: func(req *http.Request) {},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in request header",
|
||||
expectError: "code=401, message=missing key",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, header",
|
||||
@ -110,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
conf.KeyLookup = "header:API-Key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in request header",
|
||||
expectError: "code=401, message=missing key",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, query",
|
||||
@ -130,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
conf.KeyLookup = "query:key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in the query string",
|
||||
expectError: "code=401, message=missing key",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, form",
|
||||
@ -155,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
conf.KeyLookup = "form:key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in the form",
|
||||
expectError: "code=401, message=missing key",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, cookie",
|
||||
@ -179,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
conf.KeyLookup = "cookie:key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in cookies: http: named cookie not present",
|
||||
expectError: "code=401, message=missing key",
|
||||
},
|
||||
{
|
||||
name: "nok, custom errorHandler, error from extractor",
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "header:token"
|
||||
conf.ErrorHandler = func(err error, context echo.Context) error {
|
||||
conf.ErrorHandler = func(c echo.Context, err error) error {
|
||||
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
|
||||
httpError.Internal = err
|
||||
return httpError
|
||||
}
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=418, message=custom, internal=missing key in request header",
|
||||
expectError: "code=418, message=custom, internal=code=400, message=missing or malformed value",
|
||||
},
|
||||
{
|
||||
name: "nok, custom errorHandler, error from validator",
|
||||
@ -200,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.ErrorHandler = func(err error, context echo.Context) error {
|
||||
conf.ErrorHandler = func(c echo.Context, err error) error {
|
||||
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
|
||||
httpError.Internal = err
|
||||
return httpError
|
||||
@ -216,7 +216,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=invalid key, internal=some user defined error",
|
||||
expectError: "code=401, message=Unauthorized, internal=some user defined error",
|
||||
},
|
||||
}
|
||||
|
||||
@ -257,3 +257,96 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyAuthWithConfig_errors(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenConfig KeyAuthConfig
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, no error",
|
||||
whenConfig: KeyAuthConfig{
|
||||
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, missing validator func",
|
||||
whenConfig: KeyAuthConfig{
|
||||
Validator: nil,
|
||||
},
|
||||
expectError: "echo key-auth middleware requires a validator function",
|
||||
},
|
||||
{
|
||||
name: "ok, extractor source can not be split",
|
||||
whenConfig: KeyAuthConfig{
|
||||
KeyLookup: "nope",
|
||||
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
|
||||
},
|
||||
{
|
||||
name: "ok, no extractors",
|
||||
whenConfig: KeyAuthConfig{
|
||||
KeyLookup: "nope:nope",
|
||||
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mw, err := tc.whenConfig.ToMiddleware()
|
||||
if tc.expectError != "" {
|
||||
assert.Nil(t, mw)
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NotNil(t, mw)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
KeyAuthWithConfig(KeyAuthConfig{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
|
||||
handlerCalled := false
|
||||
var authValue string
|
||||
handler := func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
authValue = c.Get("auth").(string)
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
|
||||
Validator: testKeyValidator,
|
||||
ErrorHandler: func(c echo.Context, err error) error {
|
||||
// could check error to decide if we can swallow the error
|
||||
c.Set("auth", "public")
|
||||
return nil
|
||||
},
|
||||
})(handler)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// no auth header this time
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := middlewareChain(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, handlerCalled)
|
||||
assert.Equal(t, "public", authValue)
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -10,81 +11,78 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/color"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
type (
|
||||
// LoggerConfig defines the config for Logger middleware.
|
||||
LoggerConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// LoggerConfig defines the config for Logger middleware.
|
||||
type LoggerConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Tags to construct the logger format.
|
||||
//
|
||||
// - time_unix
|
||||
// - time_unix_nano
|
||||
// - time_rfc3339
|
||||
// - time_rfc3339_nano
|
||||
// - time_custom
|
||||
// - id (Request ID)
|
||||
// - remote_ip
|
||||
// - uri
|
||||
// - host
|
||||
// - method
|
||||
// - path
|
||||
// - protocol
|
||||
// - referer
|
||||
// - user_agent
|
||||
// - status
|
||||
// - error
|
||||
// - latency (In nanoseconds)
|
||||
// - latency_human (Human readable)
|
||||
// - bytes_in (Bytes received)
|
||||
// - bytes_out (Bytes sent)
|
||||
// - header:<NAME>
|
||||
// - query:<NAME>
|
||||
// - form:<NAME>
|
||||
//
|
||||
// Example "${remote_ip} ${status}"
|
||||
//
|
||||
// Optional. Default value DefaultLoggerConfig.Format.
|
||||
Format string `yaml:"format"`
|
||||
// Tags to construct the logger format.
|
||||
//
|
||||
// - time_unix
|
||||
// - time_unix_nano
|
||||
// - time_rfc3339
|
||||
// - time_rfc3339_nano
|
||||
// - time_custom
|
||||
// - id (Request ID)
|
||||
// - remote_ip
|
||||
// - uri
|
||||
// - host
|
||||
// - method
|
||||
// - path
|
||||
// - protocol
|
||||
// - referer
|
||||
// - user_agent
|
||||
// - status
|
||||
// - error
|
||||
// - latency (In nanoseconds)
|
||||
// - latency_human (Human readable)
|
||||
// - bytes_in (Bytes received)
|
||||
// - bytes_out (Bytes sent)
|
||||
// - header:<NAME>
|
||||
// - query:<NAME>
|
||||
// - form:<NAME>
|
||||
//
|
||||
// Example "${remote_ip} ${status}"
|
||||
//
|
||||
// Optional. Default value DefaultLoggerConfig.Format.
|
||||
Format string
|
||||
|
||||
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
|
||||
CustomTimeFormat string `yaml:"custom_time_format"`
|
||||
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
|
||||
CustomTimeFormat string
|
||||
|
||||
// Output is a writer where logs in JSON format are written.
|
||||
// Optional. Default value os.Stdout.
|
||||
Output io.Writer
|
||||
// Output is a writer where logs in JSON format are written.
|
||||
// Optional. Default destination `echo.Logger.Infof()`
|
||||
Output io.Writer
|
||||
|
||||
template *fasttemplate.Template
|
||||
colorer *color.Color
|
||||
pool *sync.Pool
|
||||
}
|
||||
)
|
||||
template *fasttemplate.Template
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultLoggerConfig is the default Logger middleware config.
|
||||
DefaultLoggerConfig = LoggerConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
|
||||
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
|
||||
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
|
||||
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
|
||||
CustomTimeFormat: "2006-01-02 15:04:05.00000",
|
||||
colorer: color.New(),
|
||||
}
|
||||
)
|
||||
// DefaultLoggerConfig is the default Logger middleware config.
|
||||
var DefaultLoggerConfig = LoggerConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
|
||||
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
|
||||
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
|
||||
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
|
||||
CustomTimeFormat: "2006-01-02 15:04:05.00000",
|
||||
}
|
||||
|
||||
// Logger returns a middleware that logs HTTP requests.
|
||||
func Logger() echo.MiddlewareFunc {
|
||||
return LoggerWithConfig(DefaultLoggerConfig)
|
||||
}
|
||||
|
||||
// LoggerWithConfig returns a Logger middleware with config.
|
||||
// See: `Logger()`.
|
||||
// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration.
|
||||
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration
|
||||
func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultLoggerConfig.Skipper
|
||||
@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
||||
if config.Format == "" {
|
||||
config.Format = DefaultLoggerConfig.Format
|
||||
}
|
||||
if config.Output == nil {
|
||||
config.Output = DefaultLoggerConfig.Output
|
||||
}
|
||||
|
||||
config.template = fasttemplate.New(config.Format, "${", "}")
|
||||
config.colorer = color.New()
|
||||
config.colorer.SetOutput(config.Output)
|
||||
config.pool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return bytes.NewBuffer(make([]byte, 256))
|
||||
@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
|
||||
start := time.Now()
|
||||
if err = next(c); err != nil {
|
||||
c.Error(err)
|
||||
}
|
||||
err := next(c)
|
||||
stop := time.Now()
|
||||
|
||||
buf := config.pool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer config.pool.Put(buf)
|
||||
|
||||
if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
|
||||
_, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
|
||||
switch tag {
|
||||
case "time_unix":
|
||||
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
|
||||
@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
||||
case "user_agent":
|
||||
return buf.WriteString(req.UserAgent())
|
||||
case "status":
|
||||
n := res.Status
|
||||
s := config.colorer.Green(n)
|
||||
switch {
|
||||
case n >= 500:
|
||||
s = config.colorer.Red(n)
|
||||
case n >= 400:
|
||||
s = config.colorer.Yellow(n)
|
||||
case n >= 300:
|
||||
s = config.colorer.Cyan(n)
|
||||
status := res.Status
|
||||
if err != nil {
|
||||
if httpErr, ok := err.(*echo.HTTPError); ok {
|
||||
status = httpErr.Code
|
||||
}
|
||||
}
|
||||
return buf.WriteString(s)
|
||||
return buf.WriteString(strconv.Itoa(status))
|
||||
case "error":
|
||||
if err != nil {
|
||||
// Error may contain invalid JSON e.g. `"`
|
||||
@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
|
||||
case strings.HasPrefix(tag, "form:"):
|
||||
return buf.Write([]byte(c.FormValue(tag[5:])))
|
||||
case strings.HasPrefix(tag, "cookie:"):
|
||||
cookie, err := c.Cookie(tag[7:])
|
||||
if err == nil {
|
||||
cookie, cookieErr := c.Cookie(tag[7:])
|
||||
if cookieErr == nil {
|
||||
return buf.Write([]byte(cookie.Value))
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}); err != nil {
|
||||
return
|
||||
})
|
||||
if tmplErr != nil {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err)
|
||||
}
|
||||
return fmt.Errorf("failed to create log from template: %w", tmplErr)
|
||||
}
|
||||
|
||||
if config.Output == nil {
|
||||
_, err = c.Logger().Output().Write(buf.Bytes())
|
||||
return
|
||||
if config.Output != nil {
|
||||
if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil {
|
||||
return lErr
|
||||
}
|
||||
} else {
|
||||
if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil {
|
||||
return lErr
|
||||
}
|
||||
}
|
||||
_, err = config.Output.Write(buf.Bytes())
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger.SetOutput(buf)
|
||||
e.Logger = &testLogger{output: buf}
|
||||
ip := "127.0.0.1"
|
||||
h := Logger()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
|
@ -6,28 +6,24 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// MethodOverrideConfig defines the config for MethodOverride middleware.
|
||||
MethodOverrideConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// MethodOverrideConfig defines the config for MethodOverride middleware.
|
||||
type MethodOverrideConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Getter is a function that gets overridden method from the request.
|
||||
// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
|
||||
Getter MethodOverrideGetter
|
||||
}
|
||||
// Getter is a function that gets overridden method from the request.
|
||||
// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
|
||||
Getter MethodOverrideGetter
|
||||
}
|
||||
|
||||
// MethodOverrideGetter is a function that gets overridden method from the request
|
||||
MethodOverrideGetter func(echo.Context) string
|
||||
)
|
||||
// MethodOverrideGetter is a function that gets overridden method from the request
|
||||
type MethodOverrideGetter func(echo.Context) string
|
||||
|
||||
var (
|
||||
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
|
||||
DefaultMethodOverrideConfig = MethodOverrideConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
|
||||
}
|
||||
)
|
||||
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
|
||||
var DefaultMethodOverrideConfig = MethodOverrideConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
|
||||
}
|
||||
|
||||
// MethodOverride returns a MethodOverride middleware.
|
||||
// MethodOverride middleware checks for the overridden method from the request and
|
||||
@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc {
|
||||
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
|
||||
}
|
||||
|
||||
// MethodOverrideWithConfig returns a MethodOverride middleware with config.
|
||||
// See: `MethodOverride()`.
|
||||
// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
|
||||
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
|
||||
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultMethodOverrideConfig.Skipper
|
||||
@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
|
||||
|
@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
|
||||
c := e.NewContext(req, rec)
|
||||
m(h)(c)
|
||||
|
||||
err := m(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.MethodDelete, req.Method)
|
||||
|
||||
}
|
||||
|
||||
func TestMethodOverride_formParam(t *testing.T) {
|
||||
e := echo.New()
|
||||
m := MethodOverride()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
// Override with form parameter
|
||||
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
|
||||
rec = httptest.NewRecorder()
|
||||
m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
|
||||
rec := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
c = e.NewContext(req, rec)
|
||||
m(h)(c)
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err = m(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.MethodDelete, req.Method)
|
||||
}
|
||||
|
||||
func TestMethodOverride_queryParam(t *testing.T) {
|
||||
e := echo.New()
|
||||
m := MethodOverride()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
// Override with query parameter
|
||||
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
|
||||
req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
m(h)(c)
|
||||
m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err = m(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.MethodDelete, req.Method)
|
||||
}
|
||||
|
||||
func TestMethodOverride_ignoreGet(t *testing.T) {
|
||||
e := echo.New()
|
||||
m := MethodOverride()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
// Ignore `GET`
|
||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := m(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.MethodGet, req.Method)
|
||||
}
|
||||
|
@ -9,14 +9,11 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// Skipper defines a function to skip middleware. Returning true skips processing
|
||||
// the middleware.
|
||||
Skipper func(echo.Context) bool
|
||||
// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
|
||||
type Skipper func(c echo.Context) bool
|
||||
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
BeforeFunc func(echo.Context)
|
||||
)
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
type BeforeFunc func(c echo.Context)
|
||||
|
||||
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
||||
groups := pattern.FindAllStringSubmatch(input, -1)
|
||||
@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
|
||||
func DefaultSkipper(echo.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
|
||||
mw, err := config.ToMiddleware()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return mw
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
@ -20,85 +21,81 @@ import (
|
||||
|
||||
// TODO: Handle TLS proxy
|
||||
|
||||
type (
|
||||
// ProxyConfig defines the config for Proxy middleware.
|
||||
ProxyConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// ProxyConfig defines the config for Proxy middleware.
|
||||
type ProxyConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Balancer defines a load balancing technique.
|
||||
// Required.
|
||||
Balancer ProxyBalancer
|
||||
// Balancer defines a load balancing technique.
|
||||
// Required.
|
||||
Balancer ProxyBalancer
|
||||
|
||||
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Examples:
|
||||
// "/old": "/new",
|
||||
// "/api/*": "/$1",
|
||||
// "/js/*": "/public/javascripts/$1",
|
||||
// "/users/*/orders/*": "/user/$1/order/$2",
|
||||
Rewrite map[string]string
|
||||
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Examples:
|
||||
// "/old": "/new",
|
||||
// "/api/*": "/$1",
|
||||
// "/js/*": "/public/javascripts/$1",
|
||||
// "/users/*/orders/*": "/user/$1/order/$2",
|
||||
Rewrite map[string]string
|
||||
|
||||
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
|
||||
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
|
||||
// Example:
|
||||
// "^/old/[0.9]+/": "/new",
|
||||
// "^/api/.+?/(.*)": "/v2/$1",
|
||||
RegexRewrite map[*regexp.Regexp]string
|
||||
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
|
||||
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
|
||||
// Example:
|
||||
// "^/old/[0.9]+/": "/new",
|
||||
// "^/api/.+?/(.*)": "/v2/$1",
|
||||
RegexRewrite map[*regexp.Regexp]string
|
||||
|
||||
// Context key to store selected ProxyTarget into context.
|
||||
// Optional. Default value "target".
|
||||
ContextKey string
|
||||
// Context key to store selected ProxyTarget into context.
|
||||
// Optional. Default value "target".
|
||||
ContextKey string
|
||||
|
||||
// To customize the transport to remote.
|
||||
// Examples: If custom TLS certificates are required.
|
||||
Transport http.RoundTripper
|
||||
// To customize the transport to remote.
|
||||
// Examples: If custom TLS certificates are required.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// ModifyResponse defines function to modify response from ProxyTarget.
|
||||
ModifyResponse func(*http.Response) error
|
||||
}
|
||||
// ModifyResponse defines function to modify response from ProxyTarget.
|
||||
ModifyResponse func(*http.Response) error
|
||||
}
|
||||
|
||||
// ProxyTarget defines the upstream target.
|
||||
ProxyTarget struct {
|
||||
Name string
|
||||
URL *url.URL
|
||||
Meta echo.Map
|
||||
}
|
||||
// ProxyTarget defines the upstream target.
|
||||
type ProxyTarget struct {
|
||||
Name string
|
||||
URL *url.URL
|
||||
Meta echo.Map
|
||||
}
|
||||
|
||||
// ProxyBalancer defines an interface to implement a load balancing technique.
|
||||
ProxyBalancer interface {
|
||||
AddTarget(*ProxyTarget) bool
|
||||
RemoveTarget(string) bool
|
||||
Next(echo.Context) *ProxyTarget
|
||||
}
|
||||
// ProxyBalancer defines an interface to implement a load balancing technique.
|
||||
type ProxyBalancer interface {
|
||||
AddTarget(*ProxyTarget) bool
|
||||
RemoveTarget(string) bool
|
||||
Next(echo.Context) *ProxyTarget
|
||||
}
|
||||
|
||||
commonBalancer struct {
|
||||
targets []*ProxyTarget
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
type commonBalancer struct {
|
||||
targets []*ProxyTarget
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// RandomBalancer implements a random load balancing technique.
|
||||
randomBalancer struct {
|
||||
*commonBalancer
|
||||
random *rand.Rand
|
||||
}
|
||||
// RandomBalancer implements a random load balancing technique.
|
||||
type randomBalancer struct {
|
||||
*commonBalancer
|
||||
random *rand.Rand
|
||||
}
|
||||
|
||||
// RoundRobinBalancer implements a round-robin load balancing technique.
|
||||
roundRobinBalancer struct {
|
||||
*commonBalancer
|
||||
i uint32
|
||||
}
|
||||
)
|
||||
// RoundRobinBalancer implements a round-robin load balancing technique.
|
||||
type roundRobinBalancer struct {
|
||||
*commonBalancer
|
||||
i uint32
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultProxyConfig is the default Proxy middleware config.
|
||||
DefaultProxyConfig = ProxyConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
ContextKey: "target",
|
||||
}
|
||||
)
|
||||
// DefaultProxyConfig is the default Proxy middleware config.
|
||||
var DefaultProxyConfig = ProxyConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
ContextKey: "target",
|
||||
}
|
||||
|
||||
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
in, _, err := c.Response().Hijack()
|
||||
if err != nil {
|
||||
@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
|
||||
return ProxyWithConfig(c)
|
||||
}
|
||||
|
||||
// ProxyWithConfig returns a Proxy middleware with config.
|
||||
// See: `Proxy()`
|
||||
// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
|
||||
//
|
||||
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
|
||||
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
|
||||
func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultProxyConfig.Skipper
|
||||
}
|
||||
if config.ContextKey == "" {
|
||||
config.ContextKey = DefaultProxyConfig.ContextKey
|
||||
}
|
||||
if config.Balancer == nil {
|
||||
panic("echo: proxy middleware requires balancer")
|
||||
return nil, errors.New("echo proxy middleware requires balancer")
|
||||
}
|
||||
|
||||
if config.Rewrite != nil {
|
||||
@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
// Proxy
|
||||
switch {
|
||||
case c.IsWebSocket():
|
||||
proxyRaw(tgt, c).ServeHTTP(res, req)
|
||||
proxyRaw(c, tgt).ServeHTTP(res, req)
|
||||
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
|
||||
default:
|
||||
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
|
||||
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
|
||||
}
|
||||
if e, ok := c.Get("_error").(error); ok {
|
||||
err = e
|
||||
@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StatusCodeContextCanceled is a custom HTTP status code for situations
|
||||
@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
|
||||
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
|
||||
const StatusCodeContextCanceled = 499
|
||||
|
||||
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
|
||||
func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
|
||||
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
|
||||
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
|
||||
desc := tgt.URL.String()
|
||||
|
@ -55,7 +55,7 @@ func TestProxy(t *testing.T) {
|
||||
|
||||
// Random
|
||||
e := echo.New()
|
||||
e.Use(Proxy(rb))
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
@ -77,7 +77,7 @@ func TestProxy(t *testing.T) {
|
||||
// Round-robin
|
||||
rrb := NewRoundRobinBalancer(targets)
|
||||
e = echo.New()
|
||||
e.Use(Proxy(rrb))
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
@ -113,15 +113,20 @@ func TestProxy(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
rrb1 := NewRoundRobinBalancer(targets)
|
||||
|
||||
e = echo.New()
|
||||
e.Use(contextObserver)
|
||||
e.Use(Proxy(rrb1))
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
ProxyWithConfig(ProxyConfig{Balancer: nil})
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyRealIPHeader(t *testing.T) {
|
||||
// Setup
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) {
|
||||
url, _ := url.Parse(upstream.URL)
|
||||
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
|
||||
e := echo.New()
|
||||
e.Use(Proxy(rrb))
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) {
|
||||
|
||||
// Random
|
||||
e := echo.New()
|
||||
e.Use(Proxy(rb))
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
|
||||
rb := NewRandomBalancer(nil)
|
||||
assert.True(t, rb.AddTarget(target))
|
||||
e := echo.New()
|
||||
e.Use(Proxy(rb))
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@ -9,39 +10,33 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type (
|
||||
// RateLimiterStore is the interface to be implemented by custom stores.
|
||||
RateLimiterStore interface {
|
||||
// Stores for the rate limiter have to implement the Allow method
|
||||
Allow(identifier string) (bool, error)
|
||||
}
|
||||
)
|
||||
// RateLimiterStore is the interface to be implemented by custom stores.
|
||||
type RateLimiterStore interface {
|
||||
Allow(identifier string) (bool, error)
|
||||
}
|
||||
|
||||
type (
|
||||
// RateLimiterConfig defines the configuration for the rate limiter
|
||||
RateLimiterConfig struct {
|
||||
Skipper Skipper
|
||||
BeforeFunc BeforeFunc
|
||||
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
|
||||
IdentifierExtractor Extractor
|
||||
// Store defines a store for the rate limiter
|
||||
Store RateLimiterStore
|
||||
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
|
||||
ErrorHandler func(context echo.Context, err error) error
|
||||
// DenyHandler provides a handler to be called when RateLimiter denies access
|
||||
DenyHandler func(context echo.Context, identifier string, err error) error
|
||||
}
|
||||
// Extractor is used to extract data from echo.Context
|
||||
Extractor func(context echo.Context) (string, error)
|
||||
)
|
||||
// RateLimiterConfig defines the configuration for the rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
Skipper Skipper
|
||||
BeforeFunc BeforeFunc
|
||||
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
|
||||
IdentifierExtractor Extractor
|
||||
// Store defines a store for the rate limiter
|
||||
Store RateLimiterStore
|
||||
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
|
||||
ErrorHandler func(context echo.Context, err error) error
|
||||
// DenyHandler provides a handler to be called when RateLimiter denies access
|
||||
DenyHandler func(context echo.Context, identifier string, err error) error
|
||||
}
|
||||
|
||||
// errors
|
||||
var (
|
||||
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
|
||||
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
|
||||
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
|
||||
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
|
||||
)
|
||||
// Extractor is used to extract data from echo.Context
|
||||
type Extractor func(context echo.Context) (string, error)
|
||||
|
||||
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
|
||||
var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
|
||||
|
||||
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
|
||||
var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
|
||||
|
||||
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
|
||||
var DefaultRateLimiterConfig = RateLimiterConfig{
|
||||
@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware
|
||||
}, middleware.RateLimiterWithConfig(config))
|
||||
*/
|
||||
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
|
||||
func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultRateLimiterConfig.Skipper
|
||||
}
|
||||
@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
|
||||
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
|
||||
}
|
||||
if config.Store == nil {
|
||||
panic("Store configuration must be provided")
|
||||
return nil, errors.New("echo rate limiter store configuration must be provided")
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@ -137,35 +137,32 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
|
||||
|
||||
identifier, err := config.IdentifierExtractor(c)
|
||||
if err != nil {
|
||||
c.Error(config.ErrorHandler(c, err))
|
||||
return nil
|
||||
return config.ErrorHandler(c, err)
|
||||
}
|
||||
|
||||
if allow, err := config.Store.Allow(identifier); !allow {
|
||||
c.Error(config.DenyHandler(c, identifier, err))
|
||||
return nil
|
||||
if allow, allowErr := config.Store.Allow(identifier); !allow {
|
||||
return config.DenyHandler(c, identifier, allowErr)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type (
|
||||
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
|
||||
RateLimiterMemoryStore struct {
|
||||
visitors map[string]*Visitor
|
||||
mutex sync.Mutex
|
||||
rate rate.Limit
|
||||
burst int
|
||||
expiresIn time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
// Visitor signifies a unique user's limiter details
|
||||
Visitor struct {
|
||||
*rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
)
|
||||
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
|
||||
type RateLimiterMemoryStore struct {
|
||||
visitors map[string]*Visitor
|
||||
mutex sync.Mutex
|
||||
rate rate.Limit
|
||||
burst int
|
||||
expiresIn time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
// Visitor signifies a unique user's limiter details
|
||||
type Visitor struct {
|
||||
*rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
/*
|
||||
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) {
|
||||
|
||||
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||
|
||||
mw := RateLimiter(inMemoryStore)
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
|
||||
|
||||
testCases := []struct {
|
||||
id string
|
||||
code int
|
||||
id string
|
||||
expectErr string
|
||||
}{
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
assert.Equal(t, tc.code, rec.Code)
|
||||
err := mw(handler)(c)
|
||||
if tc.expectErr != "" {
|
||||
assert.EqualError(t, err, tc.expectErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_panicBehaviour(t *testing.T) {
|
||||
func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
|
||||
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||
|
||||
assert.Panics(t, func() {
|
||||
RateLimiter(nil)
|
||||
RateLimiterWithConfig(RateLimiterConfig{})
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
RateLimiter(inMemoryStore)
|
||||
RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
|
||||
})
|
||||
}
|
||||
|
||||
@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||
mw, err := RateLimiterConfig{
|
||||
IdentifierExtractor: func(c echo.Context) (string, error) {
|
||||
id := c.Request().Header.Get(echo.HeaderXRealIP)
|
||||
if id == "" {
|
||||
@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) {
|
||||
return ctx.JSON(http.StatusBadRequest, nil)
|
||||
},
|
||||
Store: inMemoryStore,
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
id string
|
||||
@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) {
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
err := mw(handler)(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.code, rec.Code)
|
||||
}
|
||||
}
|
||||
@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||
mw, err := RateLimiterConfig{
|
||||
IdentifierExtractor: func(c echo.Context) (string, error) {
|
||||
id := c.Request().Header.Get(echo.HeaderXRealIP)
|
||||
if id == "" {
|
||||
@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
|
||||
return id, nil
|
||||
},
|
||||
Store: inMemoryStore,
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
id string
|
||||
code int
|
||||
id string
|
||||
expectErr string
|
||||
}{
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"", http.StatusForbidden},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
|
||||
assert.Equal(t, tc.code, rec.Code)
|
||||
err := mw(handler)(c)
|
||||
if tc.expectErr != "" {
|
||||
assert.EqualError(t, err, tc.expectErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||
mw, err := RateLimiterConfig{
|
||||
Store: inMemoryStore,
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
id string
|
||||
code int
|
||||
id string
|
||||
expectErr string
|
||||
}{
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusOK},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{"127.0.0.1", http.StatusTooManyRequests},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
|
||||
assert.Equal(t, tc.code, rec.Code)
|
||||
err := mw(handler)(c)
|
||||
if tc.expectErr != "" {
|
||||
assert.EqualError(t, err, tc.expectErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||
mw, err := RateLimiterConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
|
||||
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||
return "127.0.0.1", nil
|
||||
},
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
err = mw(handler)(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, beforeFuncRan)
|
||||
}
|
||||
|
||||
@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||
mw, err := RateLimiterConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return false
|
||||
},
|
||||
@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
|
||||
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||
return "127.0.0.1", nil
|
||||
},
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
|
||||
@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
|
||||
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||
mw, err := RateLimiterConfig{
|
||||
BeforeFunc: func(c echo.Context) {
|
||||
beforeRan = true
|
||||
},
|
||||
@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
|
||||
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||
return "127.0.0.1", nil
|
||||
},
|
||||
})
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_ = mw(handler)(c)
|
||||
err = mw(handler)(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, beforeRan)
|
||||
}
|
||||
|
||||
@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
|
||||
func generateAddressList(count int) []string {
|
||||
addrs := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
addrs[i] = random.String(15)
|
||||
addrs[i] = randomString(15)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
@ -5,44 +5,34 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
)
|
||||
|
||||
type (
|
||||
// RecoverConfig defines the config for Recover middleware.
|
||||
RecoverConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// RecoverConfig defines the config for Recover middleware.
|
||||
type RecoverConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Size of the stack to be printed.
|
||||
// Optional. Default value 4KB.
|
||||
StackSize int `yaml:"stack_size"`
|
||||
// Size of the stack to be printed.
|
||||
// Optional. Default value 4KB.
|
||||
StackSize int
|
||||
|
||||
// DisableStackAll disables formatting stack traces of all other goroutines
|
||||
// into buffer after the trace for the current goroutine.
|
||||
// Optional. Default value false.
|
||||
DisableStackAll bool `yaml:"disable_stack_all"`
|
||||
// DisableStackAll disables formatting stack traces of all other goroutines
|
||||
// into buffer after the trace for the current goroutine.
|
||||
// Optional. Default value false.
|
||||
DisableStackAll bool
|
||||
|
||||
// DisablePrintStack disables printing stack trace.
|
||||
// Optional. Default value as false.
|
||||
DisablePrintStack bool `yaml:"disable_print_stack"`
|
||||
// DisablePrintStack disables printing stack trace.
|
||||
// Optional. Default value as false.
|
||||
DisablePrintStack bool
|
||||
}
|
||||
|
||||
// LogLevel is log level to printing stack trace.
|
||||
// Optional. Default value 0 (Print).
|
||||
LogLevel log.Lvl
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultRecoverConfig is the default Recover middleware config.
|
||||
DefaultRecoverConfig = RecoverConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
StackSize: 4 << 10, // 4 KB
|
||||
DisableStackAll: false,
|
||||
DisablePrintStack: false,
|
||||
LogLevel: 0,
|
||||
}
|
||||
)
|
||||
// DefaultRecoverConfig is the default Recover middleware config.
|
||||
var DefaultRecoverConfig = RecoverConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
StackSize: 4 << 10, // 4 KB
|
||||
DisableStackAll: false,
|
||||
DisablePrintStack: false,
|
||||
}
|
||||
|
||||
// Recover returns a middleware which recovers from panics anywhere in the chain
|
||||
// and handles the control to the centralized HTTPErrorHandler.
|
||||
@ -50,9 +40,13 @@ func Recover() echo.MiddlewareFunc {
|
||||
return RecoverWithConfig(DefaultRecoverConfig)
|
||||
}
|
||||
|
||||
// RecoverWithConfig returns a Recover middleware with config.
|
||||
// See: `Recover()`.
|
||||
// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
|
||||
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
|
||||
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultRecoverConfig.Skipper
|
||||
@ -62,40 +56,26 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
return func(c echo.Context) (err error) {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err, ok := r.(error)
|
||||
tmpErr, ok := r.(error)
|
||||
if !ok {
|
||||
err = fmt.Errorf("%v", r)
|
||||
tmpErr = fmt.Errorf("%v", r)
|
||||
}
|
||||
stack := make([]byte, config.StackSize)
|
||||
length := runtime.Stack(stack, !config.DisableStackAll)
|
||||
if !config.DisablePrintStack {
|
||||
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
|
||||
switch config.LogLevel {
|
||||
case log.DEBUG:
|
||||
c.Logger().Debug(msg)
|
||||
case log.INFO:
|
||||
c.Logger().Info(msg)
|
||||
case log.WARN:
|
||||
c.Logger().Warn(msg)
|
||||
case log.ERROR:
|
||||
c.Logger().Error(msg)
|
||||
case log.OFF:
|
||||
// None.
|
||||
default:
|
||||
c.Logger().Print(msg)
|
||||
}
|
||||
stack := make([]byte, config.StackSize)
|
||||
length := runtime.Stack(stack, !config.DisableStackAll)
|
||||
tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length])
|
||||
}
|
||||
c.Error(err)
|
||||
err = tmpErr
|
||||
}
|
||||
}()
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -2,82 +2,109 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
e := echo.New()
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger.SetOutput(buf)
|
||||
e.Logger = &testLogger{output: buf}
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
|
||||
h := Recover()(func(c echo.Context) error {
|
||||
panic("test")
|
||||
}))
|
||||
h(c)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, buf.String(), "PANIC RECOVER")
|
||||
})
|
||||
err := h(c)
|
||||
assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
|
||||
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
|
||||
assert.Contains(t, buf.String(), "") // nothing is logged
|
||||
}
|
||||
|
||||
func TestRecoverWithConfig_LogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
logLevel log.Lvl
|
||||
levelName string
|
||||
}{{
|
||||
logLevel: log.DEBUG,
|
||||
levelName: "DEBUG",
|
||||
}, {
|
||||
logLevel: log.INFO,
|
||||
levelName: "INFO",
|
||||
}, {
|
||||
logLevel: log.WARN,
|
||||
levelName: "WARN",
|
||||
}, {
|
||||
logLevel: log.ERROR,
|
||||
levelName: "ERROR",
|
||||
}, {
|
||||
logLevel: log.OFF,
|
||||
levelName: "OFF",
|
||||
}}
|
||||
func TestRecover_skipper(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.levelName, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
config := RecoverConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
h := RecoverWithConfig(config)(func(c echo.Context) error {
|
||||
panic("testPANIC")
|
||||
})
|
||||
|
||||
var err error
|
||||
assert.Panics(t, func() {
|
||||
err = h(c)
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
|
||||
}
|
||||
|
||||
func TestRecoverWithConfig(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenNoPanic bool
|
||||
whenConfig RecoverConfig
|
||||
expectErrContain string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "ok, default config",
|
||||
whenConfig: DefaultRecoverConfig,
|
||||
expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
|
||||
},
|
||||
{
|
||||
name: "ok, no panic",
|
||||
givenNoPanic: true,
|
||||
whenConfig: DefaultRecoverConfig,
|
||||
expectErrContain: "",
|
||||
},
|
||||
{
|
||||
name: "ok, DisablePrintStack",
|
||||
whenConfig: RecoverConfig{
|
||||
DisablePrintStack: true,
|
||||
},
|
||||
expectErr: "testPANIC",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Logger.SetLevel(log.DEBUG)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger.SetOutput(buf)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
config := DefaultRecoverConfig
|
||||
config.LogLevel = tt.logLevel
|
||||
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
|
||||
panic("test")
|
||||
}))
|
||||
config := tc.whenConfig
|
||||
h := RecoverWithConfig(config)(func(c echo.Context) error {
|
||||
if tc.givenNoPanic {
|
||||
return nil
|
||||
}
|
||||
panic("testPANIC")
|
||||
})
|
||||
|
||||
h(c)
|
||||
err := h(c)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
|
||||
output := buf.String()
|
||||
if tt.logLevel == log.OFF {
|
||||
assert.Empty(t, output)
|
||||
if tc.expectErrContain != "" {
|
||||
assert.Contains(t, err.Error(), tc.expectErrContain)
|
||||
} else if tc.expectErr != "" {
|
||||
assert.Contains(t, err.Error(), tc.expectErr)
|
||||
} else {
|
||||
assert.Contains(t, output, "PANIC RECOVER")
|
||||
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -14,7 +15,9 @@ type RedirectConfig struct {
|
||||
|
||||
// Status code to be used when redirecting the request.
|
||||
// Optional. Default value http.StatusMovedPermanently.
|
||||
Code int `yaml:"code"`
|
||||
Code int
|
||||
|
||||
redirect redirectLogic
|
||||
}
|
||||
|
||||
// redirectLogic represents a function that given a scheme, host and uri
|
||||
@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
|
||||
|
||||
const www = "www."
|
||||
|
||||
// DefaultRedirectConfig is the default Redirect middleware config.
|
||||
var DefaultRedirectConfig = RedirectConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Code: http.StatusMovedPermanently,
|
||||
}
|
||||
// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
|
||||
var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
|
||||
|
||||
// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
|
||||
var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
|
||||
|
||||
// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
|
||||
var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
|
||||
|
||||
// RedirectWWWConfig is the WWW Redirect middleware config.
|
||||
var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
|
||||
|
||||
// RedirectNonWWWConfig is the non WWW Redirect middleware config.
|
||||
var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
|
||||
|
||||
// HTTPSRedirect redirects http requests to https.
|
||||
// For example, http://labstack.com will be redirect to https://labstack.com.
|
||||
//
|
||||
// Usage `Echo#Pre(HTTPSRedirect())`
|
||||
func HTTPSRedirect() echo.MiddlewareFunc {
|
||||
return HTTPSRedirectWithConfig(DefaultRedirectConfig)
|
||||
return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
|
||||
}
|
||||
|
||||
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
|
||||
// See `HTTPSRedirect()`.
|
||||
// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
|
||||
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
return redirect(config, func(scheme, host, uri string) (bool, string) {
|
||||
if scheme != "https" {
|
||||
return true, "https://" + host + uri
|
||||
}
|
||||
return false, ""
|
||||
})
|
||||
config.redirect = redirectHTTPS
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// HTTPSWWWRedirect redirects http requests to https www.
|
||||
@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
//
|
||||
// Usage `Echo#Pre(HTTPSWWWRedirect())`
|
||||
func HTTPSWWWRedirect() echo.MiddlewareFunc {
|
||||
return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
|
||||
return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
|
||||
}
|
||||
|
||||
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
|
||||
// See `HTTPSWWWRedirect()`.
|
||||
// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
|
||||
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
return redirect(config, func(scheme, host, uri string) (bool, string) {
|
||||
if scheme != "https" && !strings.HasPrefix(host, www) {
|
||||
return true, "https://www." + host + uri
|
||||
}
|
||||
return false, ""
|
||||
})
|
||||
config.redirect = redirectHTTPSWWW
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// HTTPSNonWWWRedirect redirects http requests to https non www.
|
||||
@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
//
|
||||
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
|
||||
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
|
||||
return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
|
||||
return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
|
||||
}
|
||||
|
||||
// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
|
||||
// See `HTTPSNonWWWRedirect()`.
|
||||
// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
|
||||
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
|
||||
if scheme != "https" {
|
||||
host = strings.TrimPrefix(host, www)
|
||||
return true, "https://" + host + uri
|
||||
}
|
||||
return false, ""
|
||||
})
|
||||
config.redirect = redirectNonHTTPSWWW
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// WWWRedirect redirects non www requests to www.
|
||||
@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
//
|
||||
// Usage `Echo#Pre(WWWRedirect())`
|
||||
func WWWRedirect() echo.MiddlewareFunc {
|
||||
return WWWRedirectWithConfig(DefaultRedirectConfig)
|
||||
return WWWRedirectWithConfig(RedirectWWWConfig)
|
||||
}
|
||||
|
||||
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
|
||||
// See `WWWRedirect()`.
|
||||
// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
|
||||
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
return redirect(config, func(scheme, host, uri string) (bool, string) {
|
||||
if !strings.HasPrefix(host, www) {
|
||||
return true, scheme + "://www." + host + uri
|
||||
}
|
||||
return false, ""
|
||||
})
|
||||
config.redirect = redirectWWW
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// NonWWWRedirect redirects www requests to non www.
|
||||
@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
//
|
||||
// Usage `Echo#Pre(NonWWWRedirect())`
|
||||
func NonWWWRedirect() echo.MiddlewareFunc {
|
||||
return NonWWWRedirectWithConfig(DefaultRedirectConfig)
|
||||
return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
|
||||
}
|
||||
|
||||
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
|
||||
// See `NonWWWRedirect()`.
|
||||
// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
|
||||
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
|
||||
return redirect(config, func(scheme, host, uri string) (bool, string) {
|
||||
if strings.HasPrefix(host, www) {
|
||||
return true, scheme + "://" + host[4:] + uri
|
||||
}
|
||||
return false, ""
|
||||
})
|
||||
config.redirect = redirectNonWWW
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
|
||||
// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
|
||||
func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultRedirectConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.Code == 0 {
|
||||
config.Code = DefaultRedirectConfig.Code
|
||||
config.Code = http.StatusMovedPermanently
|
||||
}
|
||||
if config.redirect == nil {
|
||||
return nil, errors.New("redirectConfig is missing redirect function")
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
|
||||
|
||||
req, scheme := c.Request(), c.Scheme()
|
||||
host := req.Host
|
||||
if ok, url := cb(scheme, host, req.RequestURI); ok {
|
||||
if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
|
||||
return c.Redirect(config.Code, url)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
|
||||
if scheme != "https" {
|
||||
return true, "https://" + host + uri
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
|
||||
if scheme != "https" && !strings.HasPrefix(host, www) {
|
||||
return true, "https://www." + host + uri
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
|
||||
if scheme != "https" {
|
||||
host = strings.TrimPrefix(host, www)
|
||||
return true, "https://" + host + uri
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
var redirectWWW = func(scheme, host, uri string) (bool, string) {
|
||||
if !strings.HasPrefix(host, www) {
|
||||
return true, scheme + "://www." + host + uri
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
|
||||
if strings.HasPrefix(host, www) {
|
||||
return true, scheme + "://" + host[4:] + uri
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
@ -2,45 +2,38 @@ package middleware
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
)
|
||||
|
||||
type (
|
||||
// RequestIDConfig defines the config for RequestID middleware.
|
||||
RequestIDConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// RequestIDConfig defines the config for RequestID middleware.
|
||||
type RequestIDConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Generator defines a function to generate an ID.
|
||||
// Optional. Default value random.String(32).
|
||||
Generator func() string
|
||||
// Generator defines a function to generate an ID.
|
||||
// Optional. Default value random.String(32).
|
||||
Generator func() string
|
||||
|
||||
// RequestIDHandler defines a function which is executed for a request id.
|
||||
RequestIDHandler func(echo.Context, string)
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultRequestIDConfig is the default RequestID middleware config.
|
||||
DefaultRequestIDConfig = RequestIDConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Generator: generator,
|
||||
}
|
||||
)
|
||||
// RequestIDHandler defines a function which is executed for a request id.
|
||||
RequestIDHandler func(c echo.Context, requestID string)
|
||||
}
|
||||
|
||||
// RequestID returns a X-Request-ID middleware.
|
||||
func RequestID() echo.MiddlewareFunc {
|
||||
return RequestIDWithConfig(DefaultRequestIDConfig)
|
||||
return RequestIDWithConfig(RequestIDConfig{})
|
||||
}
|
||||
|
||||
// RequestIDWithConfig returns a X-Request-ID middleware with config.
|
||||
// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration.
|
||||
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
|
||||
func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultRequestIDConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.Generator == nil {
|
||||
config.Generator = generator
|
||||
config.Generator = createRandomStringGenerator(32)
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -62,9 +55,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generator() string {
|
||||
return random.String(32)
|
||||
}, nil
|
||||
}
|
||||
|
@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
rid := RequestIDWithConfig(RequestIDConfig{})
|
||||
rid := RequestID()
|
||||
h := rid(handler)
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
|
||||
}
|
||||
|
||||
func TestMustRequestIDWithConfig_skipper(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusTeapot, "test")
|
||||
})
|
||||
|
||||
generatorCalled := false
|
||||
e.Use(RequestIDWithConfig(RequestIDConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
Generator: func() string {
|
||||
generatorCalled = true
|
||||
return "customGenerator"
|
||||
},
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.Equal(t, "test", res.Body.String())
|
||||
|
||||
assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
|
||||
assert.False(t, generatorCalled)
|
||||
}
|
||||
|
||||
func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
rid := RequestIDWithConfig(RequestIDConfig{
|
||||
Generator: func() string { return "customGenerator" },
|
||||
})
|
||||
h := rid(handler)
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
|
||||
}
|
||||
|
||||
func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
called := false
|
||||
rid := RequestIDWithConfig(RequestIDConfig{
|
||||
Generator: func() string { return "customGenerator" },
|
||||
RequestIDHandler: func(c echo.Context, s string) {
|
||||
called = true
|
||||
},
|
||||
})
|
||||
h := rid(handler)
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
|
||||
assert.True(t, called)
|
||||
}
|
||||
|
||||
func TestRequestIDWithConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
rid, err := RequestIDConfig{}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
h := rid(handler)
|
||||
h(c)
|
||||
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
|
||||
|
||||
// Custom generator and handler
|
||||
customID := "customGenerator"
|
||||
calledHandler := false
|
||||
// Custom generator
|
||||
rid = RequestIDWithConfig(RequestIDConfig{
|
||||
Generator: func() string { return customID },
|
||||
RequestIDHandler: func(_ echo.Context, id string) {
|
||||
calledHandler = true
|
||||
assert.Equal(t, customID, id)
|
||||
},
|
||||
Generator: func() string { return "customGenerator" },
|
||||
})
|
||||
h = rid(handler)
|
||||
h(c)
|
||||
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
|
||||
assert.True(t, calledHandler)
|
||||
}
|
||||
|
||||
func TestRequestID_IDNotAltered(t *testing.T) {
|
||||
|
@ -24,6 +24,7 @@ import (
|
||||
// LogStatus: true,
|
||||
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
|
||||
// logger.Info().
|
||||
// Date("request_start", v.StartTime).
|
||||
// Str("URI", v.URI).
|
||||
// Int("status", v.Status).
|
||||
// Msg("request")
|
||||
@ -39,6 +40,7 @@ import (
|
||||
// LogStatus: true,
|
||||
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
|
||||
// logger.Info("request",
|
||||
// zap.Time("request_start", v.StartTime),
|
||||
// zap.String("URI", v.URI),
|
||||
// zap.Int("status", v.Status),
|
||||
// )
|
||||
@ -54,8 +56,9 @@ import (
|
||||
// LogStatus: true,
|
||||
// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error {
|
||||
// log.WithFields(logrus.Fields{
|
||||
// "URI": values.URI,
|
||||
// "status": values.Status,
|
||||
// "request_start": values.StartTime,
|
||||
// "URI": values.URI,
|
||||
// "status": values.Status,
|
||||
// }).Info("request")
|
||||
//
|
||||
// return nil
|
||||
@ -158,15 +161,15 @@ type RequestLoggerValues struct {
|
||||
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
|
||||
ResponseSize int64
|
||||
// Headers are list of headers from request. Note: request can contain more than one header with same value so slice
|
||||
// of values is been logger for each given header.
|
||||
// of values is what will be returned/logged for each given header.
|
||||
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
|
||||
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
|
||||
Headers map[string][]string
|
||||
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
|
||||
// with same name so slice of values is been logger for each given query param name.
|
||||
// with same name so slice of values is what will be returned/logged for each given query param name.
|
||||
QueryParams map[string][]string
|
||||
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
|
||||
// same name so slice of values is been logger for each given form value name.
|
||||
// same name so slice of values is what will be returned/logged for each given form value name.
|
||||
FormValues map[string][]string
|
||||
}
|
||||
|
||||
|
@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c := e.NewContext(req, rec).(echo.EditableContext)
|
||||
|
||||
c.SetPath("/test*")
|
||||
|
||||
|
@ -1,62 +1,58 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// RewriteConfig defines the config for Rewrite middleware.
|
||||
RewriteConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// RewriteConfig defines the config for Rewrite middleware.
|
||||
type RewriteConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Example:
|
||||
// "/old": "/new",
|
||||
// "/api/*": "/$1",
|
||||
// "/js/*": "/public/javascripts/$1",
|
||||
// "/users/*/orders/*": "/user/$1/order/$2",
|
||||
// Required.
|
||||
Rules map[string]string `yaml:"rules"`
|
||||
// Rules defines the URL path rewrite rules. The values captured in asterisk can be
|
||||
// retrieved by index e.g. $1, $2 and so on.
|
||||
// Example:
|
||||
// "/old": "/new",
|
||||
// "/api/*": "/$1",
|
||||
// "/js/*": "/public/javascripts/$1",
|
||||
// "/users/*/orders/*": "/user/$1/order/$2",
|
||||
// Required.
|
||||
Rules map[string]string
|
||||
|
||||
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
|
||||
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
|
||||
// Example:
|
||||
// "^/old/[0.9]+/": "/new",
|
||||
// "^/api/.+?/(.*)": "/v2/$1",
|
||||
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultRewriteConfig is the default Rewrite middleware config.
|
||||
DefaultRewriteConfig = RewriteConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
}
|
||||
)
|
||||
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
|
||||
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
|
||||
// Example:
|
||||
// "^/old/[0.9]+/": "/new",
|
||||
// "^/api/.+?/(.*)": "/v2/$1",
|
||||
RegexRules map[*regexp.Regexp]string
|
||||
}
|
||||
|
||||
// Rewrite returns a Rewrite middleware.
|
||||
//
|
||||
// Rewrite middleware rewrites the URL path based on the provided rules.
|
||||
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
|
||||
c := DefaultRewriteConfig
|
||||
c := RewriteConfig{}
|
||||
c.Rules = rules
|
||||
return RewriteWithConfig(c)
|
||||
}
|
||||
|
||||
// RewriteWithConfig returns a Rewrite middleware with config.
|
||||
// See: `Rewrite()`.
|
||||
// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
|
||||
//
|
||||
// Rewrite middleware rewrites the URL path based on the provided rules.
|
||||
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
if config.Rules == nil && config.RegexRules == nil {
|
||||
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
|
||||
}
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
|
||||
func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultBodyDumpConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.Rules == nil && config.RegexRules == nil {
|
||||
return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
|
||||
}
|
||||
|
||||
if config.RegexRules == nil {
|
||||
@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
e.GET("/public/*", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, c.Param("*"))
|
||||
return c.String(http.StatusOK, c.PathParam("*"))
|
||||
})
|
||||
e.GET("/*", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, c.Param("*"))
|
||||
return c.String(http.StatusOK, c.PathParam("*"))
|
||||
})
|
||||
|
||||
var testCases = []struct {
|
||||
@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
RewriteWithConfig(RewriteConfig{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMustRewriteWithConfig_skipper(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenSkipper func(c echo.Context) bool
|
||||
whenURL string
|
||||
expectURL string
|
||||
expectStatus int
|
||||
}{
|
||||
{
|
||||
name: "not skipped",
|
||||
whenURL: "/old",
|
||||
expectURL: "/new",
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "skipped",
|
||||
givenSkipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
whenURL: "/old",
|
||||
expectURL: "/old",
|
||||
expectStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.Pre(RewriteWithConfig(
|
||||
RewriteConfig{
|
||||
Skipper: tc.givenSkipper,
|
||||
Rules: map[string]string{"/old": "/new"}},
|
||||
))
|
||||
|
||||
e.GET("/new", func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
|
||||
assert.Equal(t, tc.expectStatus, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Issue #1086
|
||||
func TestEchoRewritePreMiddleware(t *testing.T) {
|
||||
e := echo.New()
|
||||
r := e.Router()
|
||||
|
||||
// Rewrite old url to new one
|
||||
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
|
||||
e.Pre(Rewrite(map[string]string{
|
||||
"/old": "/new",
|
||||
},
|
||||
))
|
||||
e.Pre(RewriteWithConfig(RewriteConfig{
|
||||
Rules: map[string]string{"/old": "/new"}}),
|
||||
)
|
||||
|
||||
// Route
|
||||
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
||||
e.Add(http.MethodGet, "/new", func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
|
||||
// Issue #1143
|
||||
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
|
||||
e := echo.New()
|
||||
r := e.Router()
|
||||
|
||||
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
|
||||
e.Pre(RewriteWithConfig(RewriteConfig{
|
||||
@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
|
||||
},
|
||||
}))
|
||||
|
||||
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
|
||||
e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "hosts")
|
||||
})
|
||||
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
|
||||
e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "eng")
|
||||
})
|
||||
|
||||
|
@ -6,84 +6,80 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// SecureConfig defines the config for Secure middleware.
|
||||
SecureConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// SecureConfig defines the config for Secure middleware.
|
||||
type SecureConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// XSSProtection provides protection against cross-site scripting attack (XSS)
|
||||
// by setting the `X-XSS-Protection` header.
|
||||
// Optional. Default value "1; mode=block".
|
||||
XSSProtection string `yaml:"xss_protection"`
|
||||
// XSSProtection provides protection against cross-site scripting attack (XSS)
|
||||
// by setting the `X-XSS-Protection` header.
|
||||
// Optional. Default value "1; mode=block".
|
||||
XSSProtection string
|
||||
|
||||
// ContentTypeNosniff provides protection against overriding Content-Type
|
||||
// header by setting the `X-Content-Type-Options` header.
|
||||
// Optional. Default value "nosniff".
|
||||
ContentTypeNosniff string `yaml:"content_type_nosniff"`
|
||||
// ContentTypeNosniff provides protection against overriding Content-Type
|
||||
// header by setting the `X-Content-Type-Options` header.
|
||||
// Optional. Default value "nosniff".
|
||||
ContentTypeNosniff string
|
||||
|
||||
// XFrameOptions can be used to indicate whether or not a browser should
|
||||
// be allowed to render a page in a <frame>, <iframe> or <object> .
|
||||
// Sites can use this to avoid clickjacking attacks, by ensuring that their
|
||||
// content is not embedded into other sites.provides protection against
|
||||
// clickjacking.
|
||||
// Optional. Default value "SAMEORIGIN".
|
||||
// Possible values:
|
||||
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
|
||||
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
|
||||
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
|
||||
XFrameOptions string `yaml:"x_frame_options"`
|
||||
// XFrameOptions can be used to indicate whether or not a browser should
|
||||
// be allowed to render a page in a <frame>, <iframe> or <object> .
|
||||
// Sites can use this to avoid clickjacking attacks, by ensuring that their
|
||||
// content is not embedded into other sites.provides protection against
|
||||
// clickjacking.
|
||||
// Optional. Default value "SAMEORIGIN".
|
||||
// Possible values:
|
||||
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
|
||||
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
|
||||
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
|
||||
XFrameOptions string
|
||||
|
||||
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
|
||||
// long (in seconds) browsers should remember that this site is only to
|
||||
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping
|
||||
// man-in-the-middle (MITM) attacks.
|
||||
// Optional. Default value 0.
|
||||
HSTSMaxAge int `yaml:"hsts_max_age"`
|
||||
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
|
||||
// long (in seconds) browsers should remember that this site is only to
|
||||
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping
|
||||
// man-in-the-middle (MITM) attacks.
|
||||
// Optional. Default value 0.
|
||||
HSTSMaxAge int
|
||||
|
||||
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
|
||||
// header, excluding all subdomains from security policy. It has no effect
|
||||
// unless HSTSMaxAge is set to a non-zero value.
|
||||
// Optional. Default value false.
|
||||
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"`
|
||||
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
|
||||
// header, excluding all subdomains from security policy. It has no effect
|
||||
// unless HSTSMaxAge is set to a non-zero value.
|
||||
// Optional. Default value false.
|
||||
HSTSExcludeSubdomains bool
|
||||
|
||||
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing
|
||||
// security against cross-site scripting (XSS), clickjacking and other code
|
||||
// injection attacks resulting from execution of malicious content in the
|
||||
// trusted web page context.
|
||||
// Optional. Default value "".
|
||||
ContentSecurityPolicy string `yaml:"content_security_policy"`
|
||||
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing
|
||||
// security against cross-site scripting (XSS), clickjacking and other code
|
||||
// injection attacks resulting from execution of malicious content in the
|
||||
// trusted web page context.
|
||||
// Optional. Default value "".
|
||||
ContentSecurityPolicy string
|
||||
|
||||
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
|
||||
// of the `Content-Security-Policy` header. This allows iterative updates of the
|
||||
// content security policy by only reporting the violations that would
|
||||
// have occurred instead of blocking the resource.
|
||||
// Optional. Default value false.
|
||||
CSPReportOnly bool `yaml:"csp_report_only"`
|
||||
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
|
||||
// of the `Content-Security-Policy` header. This allows iterative updates of the
|
||||
// content security policy by only reporting the violations that would
|
||||
// have occurred instead of blocking the resource.
|
||||
// Optional. Default value false.
|
||||
CSPReportOnly bool
|
||||
|
||||
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
|
||||
// header, which enables the domain to be included in the HSTS preload list
|
||||
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
|
||||
// Optional. Default value false.
|
||||
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"`
|
||||
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
|
||||
// header, which enables the domain to be included in the HSTS preload list
|
||||
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
|
||||
// Optional. Default value false.
|
||||
HSTSPreloadEnabled bool
|
||||
|
||||
// ReferrerPolicy sets the `Referrer-Policy` header providing security against
|
||||
// leaking potentially sensitive request paths to third parties.
|
||||
// Optional. Default value "".
|
||||
ReferrerPolicy string `yaml:"referrer_policy"`
|
||||
}
|
||||
)
|
||||
// ReferrerPolicy sets the `Referrer-Policy` header providing security against
|
||||
// leaking potentially sensitive request paths to third parties.
|
||||
// Optional. Default value "".
|
||||
ReferrerPolicy string
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultSecureConfig is the default Secure middleware config.
|
||||
DefaultSecureConfig = SecureConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
XSSProtection: "1; mode=block",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
HSTSPreloadEnabled: false,
|
||||
}
|
||||
)
|
||||
// DefaultSecureConfig is the default Secure middleware config.
|
||||
var DefaultSecureConfig = SecureConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
XSSProtection: "1; mode=block",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
HSTSPreloadEnabled: false,
|
||||
}
|
||||
|
||||
// Secure returns a Secure middleware.
|
||||
// Secure middleware provides protection against cross-site scripting (XSS) attack,
|
||||
@ -93,9 +89,13 @@ func Secure() echo.MiddlewareFunc {
|
||||
return SecureWithConfig(DefaultSecureConfig)
|
||||
}
|
||||
|
||||
// SecureWithConfig returns a Secure middleware with config.
|
||||
// See: `Secure()`.
|
||||
// SecureWithConfig returns a Secure middleware with config or panics on invalid configuration.
|
||||
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts SecureConfig to middleware or returns an error for invalid configuration
|
||||
func (config SecureConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultSecureConfig.Skipper
|
||||
@ -141,5 +141,5 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
@ -19,26 +19,40 @@ func TestSecure(t *testing.T) {
|
||||
}
|
||||
|
||||
// Default
|
||||
Secure()(h)(c)
|
||||
err := Secure()(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection))
|
||||
assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions))
|
||||
assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy))
|
||||
}
|
||||
|
||||
// Custom
|
||||
func TestSecureWithConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderXForwardedProto, "https")
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
SecureWithConfig(SecureConfig{
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
mw, err := SecureConfig{
|
||||
XSSProtection: "",
|
||||
ContentTypeNosniff: "",
|
||||
XFrameOptions: "",
|
||||
HSTSMaxAge: 3600,
|
||||
ContentSecurityPolicy: "default-src 'self'",
|
||||
ReferrerPolicy: "origin",
|
||||
})(h)(c)
|
||||
}.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = mw(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
|
||||
@ -47,11 +61,21 @@ func TestSecure(t *testing.T) {
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
|
||||
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
|
||||
|
||||
}
|
||||
|
||||
func TestSecureWithConfig_CSPReportOnly(t *testing.T) {
|
||||
// Custom with CSPReportOnly flag
|
||||
e := echo.New()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderXForwardedProto, "https")
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
SecureWithConfig(SecureConfig{
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := SecureWithConfig(SecureConfig{
|
||||
XSSProtection: "",
|
||||
ContentTypeNosniff: "",
|
||||
XFrameOptions: "",
|
||||
@ -60,6 +84,8 @@ func TestSecure(t *testing.T) {
|
||||
CSPReportOnly: true,
|
||||
ReferrerPolicy: "origin",
|
||||
})(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
|
||||
@ -67,25 +93,51 @@ func TestSecure(t *testing.T) {
|
||||
assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
|
||||
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
|
||||
}
|
||||
|
||||
func TestSecureWithConfig_HSTSPreloadEnabled(t *testing.T) {
|
||||
// Custom with CSPReportOnly flag
|
||||
e := echo.New()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// Custom, with preload option enabled
|
||||
req.Header.Set(echo.HeaderXForwardedProto, "https")
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
SecureWithConfig(SecureConfig{
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := SecureWithConfig(SecureConfig{
|
||||
HSTSMaxAge: 3600,
|
||||
HSTSPreloadEnabled: true,
|
||||
})(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
|
||||
|
||||
}
|
||||
|
||||
func TestSecureWithConfig_HSTSExcludeSubdomains(t *testing.T) {
|
||||
// Custom with CSPReportOnly flag
|
||||
e := echo.New()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
// Custom, with preload option enabled and subdomains excluded
|
||||
req.Header.Set(echo.HeaderXForwardedProto, "https")
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
SecureWithConfig(SecureConfig{
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := SecureWithConfig(SecureConfig{
|
||||
HSTSMaxAge: 3600,
|
||||
HSTSPreloadEnabled: true,
|
||||
HSTSExcludeSubdomains: true,
|
||||
})(h)(c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
|
||||
}
|
||||
|
@ -1,44 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type (
|
||||
// TrailingSlashConfig defines the config for TrailingSlash middleware.
|
||||
TrailingSlashConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// AddTrailingSlashConfig is the middleware config for adding trailing slash to the request.
|
||||
type AddTrailingSlashConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Status code to be used when redirecting the request.
|
||||
// Optional, but when provided the request is redirected using this code.
|
||||
RedirectCode int `yaml:"redirect_code"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
|
||||
DefaultTrailingSlashConfig = TrailingSlashConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
}
|
||||
)
|
||||
// Status code to be used when redirecting the request.
|
||||
// Optional, but when provided the request is redirected using this code.
|
||||
// Valid status codes: [300...308]
|
||||
RedirectCode int
|
||||
}
|
||||
|
||||
// AddTrailingSlash returns a root level (before router) middleware which adds a
|
||||
// trailing slash to the request `URL#Path`.
|
||||
//
|
||||
// Usage `Echo#Pre(AddTrailingSlash())`
|
||||
func AddTrailingSlash() echo.MiddlewareFunc {
|
||||
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
|
||||
return AddTrailingSlashWithConfig(AddTrailingSlashConfig{})
|
||||
}
|
||||
|
||||
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config.
|
||||
// See `AddTrailingSlash()`.
|
||||
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration.
|
||||
func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts AddTrailingSlashConfig to middleware or returns an error for invalid configuration
|
||||
func (config AddTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultTrailingSlashConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
|
||||
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
|
||||
return nil, errors.New("invalid redirect code for add trailing slash middleware")
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -69,7 +70,17 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RemoveTrailingSlashConfig is the middleware config for removing trailing slash from the request.
|
||||
type RemoveTrailingSlashConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Status code to be used when redirecting the request.
|
||||
// Optional, but when provided the request is redirected using this code.
|
||||
RedirectCode int
|
||||
}
|
||||
|
||||
// RemoveTrailingSlash returns a root level (before router) middleware which removes
|
||||
@ -77,15 +88,22 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
|
||||
//
|
||||
// Usage `Echo#Pre(RemoveTrailingSlash())`
|
||||
func RemoveTrailingSlash() echo.MiddlewareFunc {
|
||||
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
|
||||
return RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{})
|
||||
}
|
||||
|
||||
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
|
||||
// See `RemoveTrailingSlash()`.
|
||||
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config or panics on invalid configuration.
|
||||
func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts RemoveTrailingSlashConfig to middleware or returns an error for invalid configuration
|
||||
func (config RemoveTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultTrailingSlashConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
|
||||
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
|
||||
return nil, errors.New("invalid redirect code for remove trailing slash middleware")
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -117,7 +135,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sanitizeURI(uri string) string {
|
||||
|
@ -67,7 +67,7 @@ func TestAddTrailingSlashWithConfig(t *testing.T) {
|
||||
t.Run(tc.whenURL, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
mw := AddTrailingSlashWithConfig(TrailingSlashConfig{
|
||||
mw := AddTrailingSlashWithConfig(AddTrailingSlashConfig{
|
||||
RedirectCode: http.StatusMovedPermanently,
|
||||
})
|
||||
h := mw(func(c echo.Context) error {
|
||||
@ -203,7 +203,7 @@ func TestRemoveTrailingSlashWithConfig(t *testing.T) {
|
||||
t.Run(tc.whenURL, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{
|
||||
mw := RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{
|
||||
RedirectCode: http.StatusMovedPermanently,
|
||||
})
|
||||
h := mw(func(c echo.Context) error {
|
||||
|
@ -1,55 +1,65 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/bytes"
|
||||
)
|
||||
|
||||
type (
|
||||
// StaticConfig defines the config for Static middleware.
|
||||
StaticConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// StaticConfig defines the config for Static middleware.
|
||||
type StaticConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// Root directory from where the static content is served.
|
||||
// Required.
|
||||
Root string `yaml:"root"`
|
||||
// Root directory from where the static content is served (relative to given Filesystem).
|
||||
// `Root: "."` means root folder from Filesystem.
|
||||
// Required.
|
||||
Root string
|
||||
|
||||
// Index file for serving a directory.
|
||||
// Optional. Default value "index.html".
|
||||
Index string `yaml:"index"`
|
||||
// Filesystem provides access to the static content.
|
||||
// Optional. Defaults to echo.Filesystem (serves files from `.` folder where executable is started)
|
||||
Filesystem fs.FS
|
||||
|
||||
// Enable HTML5 mode by forwarding all not-found requests to root so that
|
||||
// SPA (single-page application) can handle the routing.
|
||||
// Optional. Default value false.
|
||||
HTML5 bool `yaml:"html5"`
|
||||
// Index file for serving a directory.
|
||||
// Optional. Default value "index.html".
|
||||
Index string
|
||||
|
||||
// Enable directory browsing.
|
||||
// Optional. Default value false.
|
||||
Browse bool `yaml:"browse"`
|
||||
// Enable HTML5 mode by forwarding all not-found requests to root so that
|
||||
// SPA (single-page application) can handle the routing.
|
||||
// Optional. Default value false.
|
||||
HTML5 bool
|
||||
|
||||
// Enable ignoring of the base of the URL path.
|
||||
// Example: when assigning a static middleware to a non root path group,
|
||||
// the filesystem path is not doubled
|
||||
// Optional. Default value false.
|
||||
IgnoreBase bool `yaml:"ignoreBase"`
|
||||
// Enable directory browsing.
|
||||
// Optional. Default value false.
|
||||
Browse bool
|
||||
|
||||
// Filesystem provides access to the static content.
|
||||
// Optional. Defaults to http.Dir(config.Root)
|
||||
Filesystem http.FileSystem `yaml:"-"`
|
||||
}
|
||||
)
|
||||
// Enable ignoring of the base of the URL path.
|
||||
// Example: when assigning a static middleware to a non root path group,
|
||||
// the filesystem path is not doubled
|
||||
// Optional. Default value false.
|
||||
IgnoreBase bool
|
||||
|
||||
const html = `
|
||||
// DisablePathUnescaping disables path parameter (param: *) unescaping. This is useful when router is set to unescape
|
||||
// all parameter and doing it again in this middleware would corrupt filename that is requested.
|
||||
DisablePathUnescaping bool
|
||||
|
||||
// DirectoryListTemplate is template to list directory contents
|
||||
// Optional. Default to `directoryListHTMLTemplate` constant below.
|
||||
DirectoryListTemplate string
|
||||
}
|
||||
|
||||
const directoryListHTMLTemplate = `
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
@ -121,25 +131,26 @@ const html = `
|
||||
</html>
|
||||
`
|
||||
|
||||
var (
|
||||
// DefaultStaticConfig is the default Static middleware config.
|
||||
DefaultStaticConfig = StaticConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Index: "index.html",
|
||||
}
|
||||
)
|
||||
// DefaultStaticConfig is the default Static middleware config.
|
||||
var DefaultStaticConfig = StaticConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Index: "index.html",
|
||||
}
|
||||
|
||||
// Static returns a Static middleware to serves static content from the provided
|
||||
// root directory.
|
||||
// Static returns a Static middleware to serves static content from the provided root directory.
|
||||
func Static(root string) echo.MiddlewareFunc {
|
||||
c := DefaultStaticConfig
|
||||
c.Root = root
|
||||
return StaticWithConfig(c)
|
||||
}
|
||||
|
||||
// StaticWithConfig returns a Static middleware with config.
|
||||
// See `Static()`.
|
||||
// StaticWithConfig returns a Static middleware to serves static content or panics on invalid configuration.
|
||||
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts StaticConfig to middleware or returns an error for invalid configuration
|
||||
func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
// Defaults
|
||||
if config.Root == "" {
|
||||
config.Root = "." // For security we want to restrict to CWD.
|
||||
@ -150,30 +161,32 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
||||
if config.Index == "" {
|
||||
config.Index = DefaultStaticConfig.Index
|
||||
}
|
||||
if config.Filesystem == nil {
|
||||
config.Filesystem = http.Dir(config.Root)
|
||||
config.Root = "."
|
||||
if config.DirectoryListTemplate == "" {
|
||||
config.DirectoryListTemplate = directoryListHTMLTemplate
|
||||
}
|
||||
|
||||
// Index template
|
||||
t, err := template.New("index").Parse(html)
|
||||
dirListTemplate, err := template.New("index").Parse(config.DirectoryListTemplate)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("echo: %v", err))
|
||||
return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", err)
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
return func(c echo.Context) error {
|
||||
if config.Skipper(c) {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
p := c.Request().URL.Path
|
||||
if strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
|
||||
p = c.Param("*")
|
||||
pathUnescape := true
|
||||
if c.RouteMatchType() == echo.RouteMatchFound && strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
|
||||
p = c.PathParam("*")
|
||||
pathUnescape = !config.DisablePathUnescaping // because router could already do PathUnescape
|
||||
}
|
||||
p, err = url.PathUnescape(p)
|
||||
if err != nil {
|
||||
return
|
||||
if pathUnescape {
|
||||
p, err = url.PathUnescape(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
|
||||
|
||||
@ -186,22 +199,29 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
file, err := openFile(config.Filesystem, name)
|
||||
currentFS := config.Filesystem
|
||||
if currentFS == nil {
|
||||
currentFS = c.Echo().Filesystem
|
||||
}
|
||||
|
||||
file, err := openFile(currentFS, name)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = next(c); err == nil {
|
||||
return err
|
||||
// when file does not exist let handler to handle that request. if it succeeds then we are done
|
||||
err = next(c)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
he, ok := err.(*echo.HTTPError)
|
||||
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index))
|
||||
// is case HTML5 mode is enabled + echo 404 we serve index to the client
|
||||
file, err = openFile(currentFS, filepath.Join(config.Root, config.Index))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -215,10 +235,10 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
index, err := openFile(config.Filesystem, filepath.Join(name, config.Index))
|
||||
index, err := openFile(currentFS, filepath.Join(name, config.Index))
|
||||
if err != nil {
|
||||
if config.Browse {
|
||||
return listDir(t, name, file, c.Response())
|
||||
return listDir(dirListTemplate, name, currentFS, file, c.Response())
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
@ -238,25 +258,24 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
|
||||
|
||||
return serveFile(c, file, info)
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func openFile(fs http.FileSystem, name string) (http.File, error) {
|
||||
func openFile(fs fs.FS, name string) (fs.File, error) {
|
||||
pathWithSlashes := filepath.ToSlash(name)
|
||||
return fs.Open(pathWithSlashes)
|
||||
}
|
||||
|
||||
func serveFile(c echo.Context, file http.File, info os.FileInfo) error {
|
||||
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file)
|
||||
func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
|
||||
ff, ok := file.(io.ReadSeeker)
|
||||
if !ok {
|
||||
return errors.New("file does not implement io.ReadSeeker")
|
||||
}
|
||||
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), ff)
|
||||
return nil
|
||||
}
|
||||
|
||||
func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) {
|
||||
files, err := dir.Readdir(-1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
func listDir(t *template.Template, name string, filesystem fs.FS, dir fs.File, res *echo.Response) error {
|
||||
// Create directory index
|
||||
res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8)
|
||||
data := struct {
|
||||
@ -265,12 +284,60 @@ func listDir(t *template.Template, name string, dir http.File, res *echo.Respons
|
||||
}{
|
||||
Name: name,
|
||||
}
|
||||
for _, f := range files {
|
||||
err := fs.WalkDir(filesystem, ".", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, infoErr := d.Info()
|
||||
if infoErr != nil {
|
||||
return fmt.Errorf("static middleware list dir error when getting file info: %w", err)
|
||||
}
|
||||
data.Files = append(data.Files, struct {
|
||||
Name string
|
||||
Dir bool
|
||||
Size string
|
||||
}{f.Name(), f.IsDir(), bytes.Format(f.Size())})
|
||||
}{d.Name(), d.IsDir(), format(info.Size())})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return t.Execute(res, data)
|
||||
}
|
||||
|
||||
// format formats bytes integer to human readable string.
|
||||
// For example, 31323 bytes will return 30.59KB.
|
||||
func format(b int64) string {
|
||||
multiple := ""
|
||||
value := float64(b)
|
||||
|
||||
switch {
|
||||
case b >= EB:
|
||||
value /= float64(EB)
|
||||
multiple = "EB"
|
||||
case b >= PB:
|
||||
value /= float64(PB)
|
||||
multiple = "PB"
|
||||
case b >= TB:
|
||||
value /= float64(TB)
|
||||
multiple = "TB"
|
||||
case b >= GB:
|
||||
value /= float64(GB)
|
||||
multiple = "GB"
|
||||
case b >= MB:
|
||||
value /= float64(MB)
|
||||
multiple = "MB"
|
||||
case b >= KB:
|
||||
value /= float64(KB)
|
||||
multiple = "KB"
|
||||
case b == 0:
|
||||
return "0"
|
||||
default:
|
||||
return strconv.FormatInt(b, 10) + "B"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.2f%s", value, multiple)
|
||||
}
|
||||
|
@ -81,14 +81,16 @@ func TestStatic_CustomFS(t *testing.T) {
|
||||
|
||||
config := StaticConfig{
|
||||
Root: ".",
|
||||
Filesystem: http.FS(tc.filesystem),
|
||||
Filesystem: tc.filesystem,
|
||||
}
|
||||
|
||||
if tc.root != "" {
|
||||
config.Root = tc.root
|
||||
}
|
||||
|
||||
middlewareFunc := StaticWithConfig(config)
|
||||
middlewareFunc, err := config.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
e.Use(middlewareFunc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
|
@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -10,6 +11,37 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStatic_useCaseForApiAndSPAs(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
// serve single page application (SPA) files from server root
|
||||
e.Use(StaticWithConfig(StaticConfig{
|
||||
Root: ".",
|
||||
// by default Echo filesystem is fixed to `./` but this does not allow `../` (moving up in folder structure past filesystem root)
|
||||
Filesystem: os.DirFS("../_fixture"),
|
||||
}))
|
||||
|
||||
// all requests to `/api/*` will end up in echo handlers (assuming there is not `api` folder and files)
|
||||
api := e.Group("/api")
|
||||
users := api.Group("/users")
|
||||
users.GET("/info", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "users info")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/info", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "users info", rec.Body.String())
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/index.html", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "<title>Echo</title>")
|
||||
|
||||
}
|
||||
|
||||
func TestStatic(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
@ -35,7 +67,7 @@ func TestStatic(t *testing.T) {
|
||||
{
|
||||
name: "ok, when html5 mode serve index for any static file that does not exist",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "../_fixture",
|
||||
Root: "_fixture",
|
||||
HTML5: true,
|
||||
},
|
||||
whenURL: "/random",
|
||||
@ -45,7 +77,7 @@ func TestStatic(t *testing.T) {
|
||||
{
|
||||
name: "ok, serve index as directory index listing files directory",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "../_fixture/certs",
|
||||
Root: "_fixture/certs",
|
||||
Browse: true,
|
||||
},
|
||||
whenURL: "/",
|
||||
@ -55,7 +87,7 @@ func TestStatic(t *testing.T) {
|
||||
{
|
||||
name: "ok, serve directory index with IgnoreBase and browse",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
|
||||
Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
|
||||
IgnoreBase: true,
|
||||
Browse: true,
|
||||
},
|
||||
@ -67,7 +99,7 @@ func TestStatic(t *testing.T) {
|
||||
{
|
||||
name: "ok, serve file with IgnoreBase",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
|
||||
Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
|
||||
IgnoreBase: true,
|
||||
Browse: true,
|
||||
},
|
||||
@ -95,15 +127,27 @@ func TestStatic(t *testing.T) {
|
||||
expectContains: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
{
|
||||
name: "ok, do not serve file, when a handler took care of the request",
|
||||
name: "ok, when no file then a handler will care of the request",
|
||||
whenURL: "/regular-handler",
|
||||
expectCode: http.StatusOK,
|
||||
expectContains: "ok",
|
||||
},
|
||||
{
|
||||
name: "ok, skip middleware and serve handler",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "_fixture/images/",
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
whenURL: "/walle.png",
|
||||
expectCode: http.StatusTeapot,
|
||||
expectContains: "walle",
|
||||
},
|
||||
{
|
||||
name: "nok, when html5 fail if the index file does not exist",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "../_fixture",
|
||||
Root: "_fixture",
|
||||
HTML5: true,
|
||||
Index: "missing.html",
|
||||
},
|
||||
@ -114,7 +158,7 @@ func TestStatic(t *testing.T) {
|
||||
name: "ok, serve from http.FileSystem",
|
||||
givenConfig: &StaticConfig{
|
||||
Root: "_fixture",
|
||||
Filesystem: http.Dir(".."),
|
||||
Filesystem: os.DirFS(".."),
|
||||
},
|
||||
whenURL: "/",
|
||||
expectCode: http.StatusOK,
|
||||
@ -125,8 +169,9 @@ func TestStatic(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Filesystem = os.DirFS("../")
|
||||
|
||||
config := StaticConfig{Root: "../_fixture"}
|
||||
config := StaticConfig{Root: "_fixture"}
|
||||
if tc.givenConfig != nil {
|
||||
config = *tc.givenConfig
|
||||
}
|
||||
@ -136,14 +181,17 @@ func TestStatic(t *testing.T) {
|
||||
subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc)
|
||||
// group without http handlers (routes) does not do anything.
|
||||
// Request is matched against http handlers (routes) that have group middleware attached to them
|
||||
subGroup.GET("", echo.NotFoundHandler)
|
||||
subGroup.GET("/*", echo.NotFoundHandler)
|
||||
subGroup.GET("", func(c echo.Context) error { return echo.ErrNotFound })
|
||||
subGroup.GET("/*", func(c echo.Context) error { return echo.ErrNotFound })
|
||||
} else {
|
||||
// middleware is on root level
|
||||
e.Use(middlewareFunc)
|
||||
e.GET("/regular-handler", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
e.GET("/walle.png", func(c echo.Context) error {
|
||||
return c.String(http.StatusTeapot, "walle")
|
||||
})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
@ -177,7 +225,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "ok",
|
||||
givenPrefix: "/images",
|
||||
givenRoot: "../_fixture/images",
|
||||
givenRoot: "_fixture/images",
|
||||
whenURL: "/group/images/walle.png",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
|
||||
@ -185,7 +233,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "No file",
|
||||
givenPrefix: "/images",
|
||||
givenRoot: "../_fixture/scripts",
|
||||
givenRoot: "_fixture/scripts",
|
||||
whenURL: "/group/images/bolt.png",
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
@ -193,7 +241,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "Directory not found (no trailing slash)",
|
||||
givenPrefix: "/images",
|
||||
givenRoot: "../_fixture/images",
|
||||
givenRoot: "_fixture/images",
|
||||
whenURL: "/group/images/",
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
@ -201,7 +249,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "Directory redirect",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/group/folder",
|
||||
expectStatus: http.StatusMovedPermanently,
|
||||
expectHeaderLocation: "/group/folder/",
|
||||
@ -211,7 +259,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
name: "Prefixed directory 404 (request URL without slash)",
|
||||
givenGroup: "_fixture",
|
||||
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/_fixture/folder", // no trailing slash
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
@ -220,7 +268,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
name: "Prefixed directory redirect (without slash redirect to slash)",
|
||||
givenGroup: "_fixture",
|
||||
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/_fixture/folder", // no trailing slash
|
||||
expectStatus: http.StatusMovedPermanently,
|
||||
expectHeaderLocation: "/_fixture/folder/",
|
||||
@ -229,7 +277,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "Directory with index.html",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/group/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
@ -237,7 +285,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "Prefixed directory with index.html (prefix ending with slash)",
|
||||
givenPrefix: "/assets/",
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/group/assets/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
@ -245,7 +293,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "Prefixed directory with index.html (prefix ending without slash)",
|
||||
givenPrefix: "/assets",
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/group/assets/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
@ -253,7 +301,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "Sub-directory with index.html",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "../_fixture",
|
||||
givenRoot: "_fixture",
|
||||
whenURL: "/group/folder/",
|
||||
expectStatus: http.StatusOK,
|
||||
expectBodyStartsWith: "<!doctype html>",
|
||||
@ -261,7 +309,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "do not allow directory traversal (backslash - windows separator)",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "../_fixture/",
|
||||
givenRoot: "_fixture/",
|
||||
whenURL: `/group/..\\middleware/basic_auth.go`,
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
@ -269,7 +317,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
{
|
||||
name: "do not allow directory traversal (slash - unix separator)",
|
||||
givenPrefix: "/",
|
||||
givenRoot: "../_fixture/",
|
||||
givenRoot: "_fixture/",
|
||||
whenURL: `/group/../middleware/basic_auth.go`,
|
||||
expectStatus: http.StatusNotFound,
|
||||
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
@ -279,6 +327,8 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Filesystem = os.DirFS("../") // so we can access test files
|
||||
|
||||
group := "/group"
|
||||
if tc.givenGroup != "" {
|
||||
group = tc.givenGroup
|
||||
@ -288,7 +338,9 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectStatus, rec.Code)
|
||||
body := rec.Body.String()
|
||||
if tc.expectBodyStartsWith != "" {
|
||||
@ -306,3 +358,73 @@ func TestStatic_GroupWithStatic(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustStaticWithConfig_panicsInvalidDirListTemplate(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
StaticWithConfig(StaticConfig{DirectoryListTemplate: `{{}`})
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormat(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
when int64
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "byte",
|
||||
when: 0,
|
||||
expect: "0",
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
when: 515,
|
||||
expect: "515B",
|
||||
},
|
||||
{
|
||||
name: "KB",
|
||||
when: 31323,
|
||||
expect: "30.59KB",
|
||||
},
|
||||
{
|
||||
name: "MB",
|
||||
when: 13231323,
|
||||
expect: "12.62MB",
|
||||
},
|
||||
{
|
||||
name: "GB",
|
||||
when: 7323232398,
|
||||
expect: "6.82GB",
|
||||
},
|
||||
{
|
||||
name: "TB",
|
||||
when: 1_099_511_627_776,
|
||||
expect: "1.00TB",
|
||||
},
|
||||
{
|
||||
name: "PB",
|
||||
when: 9923232398434432,
|
||||
expect: "8.81PB",
|
||||
},
|
||||
{
|
||||
// test with 7EB because of https://github.com/labstack/gommon/pull/38 and https://github.com/labstack/gommon/pull/43
|
||||
//
|
||||
// 8 exbi equals 2^64, therefore it cannot be stored in int64. The tests use
|
||||
// the fact that on x86_64 the following expressions holds true:
|
||||
// int64(0) - 1 == math.MaxInt64.
|
||||
//
|
||||
// However, this is not true for other platforms, specifically aarch64, s390x
|
||||
// and ppc64le.
|
||||
name: "EB",
|
||||
when: 8070450532247929000,
|
||||
expect: "7.00EB",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := format(tc.when)
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,9 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------------
|
||||
@ -55,51 +54,43 @@ import (
|
||||
// })
|
||||
//
|
||||
|
||||
type (
|
||||
// TimeoutConfig defines the config for Timeout middleware.
|
||||
TimeoutConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
// TimeoutConfig defines the config for Timeout middleware.
|
||||
type TimeoutConfig struct {
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
|
||||
// It can be used to define a custom timeout error message
|
||||
ErrorMessage string
|
||||
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
|
||||
// It can be used to define a custom timeout error message
|
||||
ErrorMessage string
|
||||
|
||||
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
|
||||
// request timeouted and we already had sent the error code (503) and message response to the client.
|
||||
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
|
||||
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
|
||||
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
|
||||
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
|
||||
// request timeouted and we already had sent the error code (503) and message response to the client.
|
||||
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
|
||||
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
|
||||
OnTimeoutRouteErrorHandler func(c echo.Context, err error)
|
||||
|
||||
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
|
||||
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
|
||||
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
|
||||
// difference over 500microseconds (0.5millisecond) response seems to be reliable
|
||||
Timeout time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultTimeoutConfig is the default Timeout middleware config.
|
||||
DefaultTimeoutConfig = TimeoutConfig{
|
||||
Skipper: DefaultSkipper,
|
||||
Timeout: 0,
|
||||
ErrorMessage: "",
|
||||
}
|
||||
)
|
||||
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
|
||||
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
|
||||
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
|
||||
// difference over 500microseconds (0.5millisecond) response seems to be reliable
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
|
||||
// call runs for longer than its time limit. NB: timeout does not stop handler execution.
|
||||
func Timeout() echo.MiddlewareFunc {
|
||||
return TimeoutWithConfig(DefaultTimeoutConfig)
|
||||
return TimeoutWithConfig(TimeoutConfig{})
|
||||
}
|
||||
|
||||
// TimeoutWithConfig returns a Timeout middleware with config.
|
||||
// See: `Timeout()`.
|
||||
// TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
|
||||
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
return toMiddlewareOrPanic(config)
|
||||
}
|
||||
|
||||
// ToMiddleware converts TimeoutConfig to middleware or returns an error for invalid configuration
|
||||
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultTimeoutConfig.Skipper
|
||||
config.Skipper = DefaultSkipper
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -108,29 +99,30 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
handlerWrapper := echoHandlerFuncWrapper{
|
||||
ctx: c,
|
||||
handler: next,
|
||||
errChan: make(chan error, 1),
|
||||
errChan: errChan,
|
||||
errHandler: config.OnTimeoutRouteErrorHandler,
|
||||
}
|
||||
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
|
||||
handler.ServeHTTP(c.Response().Writer, c.Request())
|
||||
|
||||
select {
|
||||
case err := <-handlerWrapper.errChan:
|
||||
case err := <-errChan:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type echoHandlerFuncWrapper struct {
|
||||
ctx echo.Context
|
||||
handler echo.HandlerFunc
|
||||
errHandler func(err error, c echo.Context)
|
||||
errHandler func(c echo.Context, err error)
|
||||
errChan chan error
|
||||
}
|
||||
|
||||
@ -156,7 +148,7 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
|
||||
err := t.handler(t.ctx)
|
||||
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
|
||||
if err != nil && t.errHandler != nil {
|
||||
t.errHandler(err, t.ctx)
|
||||
t.errHandler(t.ctx, err)
|
||||
}
|
||||
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
@ -14,9 +16,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTimeoutSkipper(t *testing.T) {
|
||||
@ -111,7 +110,7 @@ func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
|
||||
actualErrChan := make(chan error, 1)
|
||||
m := TimeoutWithConfig(TimeoutConfig{
|
||||
Timeout: 1 * time.Millisecond,
|
||||
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) {
|
||||
OnTimeoutRouteErrorHandler: func(c echo.Context, err error) {
|
||||
actualErrChan <- err
|
||||
},
|
||||
})
|
||||
@ -360,7 +359,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger.SetOutput(buf)
|
||||
e.Logger = &testLogger{output: buf}
|
||||
|
||||
// NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first
|
||||
// FIXME: I have no idea how to fix this without adding mutexes.
|
||||
@ -424,7 +423,7 @@ func startServer(e *echo.Echo) (*http.Server, string, error) {
|
||||
|
||||
s := http.Server{
|
||||
Handler: e,
|
||||
ErrorLog: log.New(e.Logger.Output(), "echo:", 0),
|
||||
ErrorLog: log.New(e.Logger, "echo:", 0),
|
||||
}
|
||||
|
||||
errCh := make(chan error)
|
||||
|
@ -1,9 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
_ = int64(1 << (10 * iota)) // ignore first value by assigning to blank identifier
|
||||
// KB is 1 KiloByte = 1024 bytes
|
||||
KB
|
||||
// MB is 1 Megabyte = 1_048_576 bytes
|
||||
MB
|
||||
// GB is 1 Gigabyte = 1_073_741_824 bytes
|
||||
GB
|
||||
// TB is 1 Terabyte = 1_099_511_627_776 bytes
|
||||
TB
|
||||
// PB is 1 Petabyte = 1_125_899_906_842_624 bytes
|
||||
PB
|
||||
// EB is 1 Exabyte = 1_152_921_504_606_847_000 bytes
|
||||
EB
|
||||
)
|
||||
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
@ -52,3 +70,24 @@ func matchSubdomain(domain, pattern string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func createRandomStringGenerator(length uint8) func() string {
|
||||
return func() string {
|
||||
return randomString(length)
|
||||
}
|
||||
}
|
||||
|
||||
func randomString(length uint8) string {
|
||||
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
// we are out of random. let the request fail
|
||||
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
|
||||
}
|
||||
for i, b := range bytes {
|
||||
bytes[i] = charset[b%byte(len(charset))]
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
@ -1,11 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testLogger struct {
|
||||
output io.Writer
|
||||
}
|
||||
|
||||
func (l *testLogger) Write(p []byte) (n int, err error) {
|
||||
return l.output.Write(p)
|
||||
}
|
||||
|
||||
func (l *testLogger) Error(err error) {
|
||||
_, _ = l.output.Write([]byte(err.Error()))
|
||||
}
|
||||
|
||||
func Test_matchScheme(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, pattern string
|
||||
@ -93,3 +105,27 @@ func Test_matchSubdomain(t *testing.T) {
|
||||
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomString(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenLength uint8
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "ok, 16",
|
||||
whenLength: 16,
|
||||
},
|
||||
{
|
||||
name: "ok, 32",
|
||||
whenLength: 32,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
uid := randomString(tc.whenLength)
|
||||
assert.Len(t, uid, int(tc.whenLength))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
31
response.go
31
response.go
@ -2,24 +2,23 @@ package echo
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
// Response wraps an http.ResponseWriter and implements its interface to be used
|
||||
// by an HTTP handler to construct an HTTP response.
|
||||
// See: https://golang.org/pkg/net/http/#ResponseWriter
|
||||
Response struct {
|
||||
echo *Echo
|
||||
beforeFuncs []func()
|
||||
afterFuncs []func()
|
||||
Writer http.ResponseWriter
|
||||
Status int
|
||||
Size int64
|
||||
Committed bool
|
||||
}
|
||||
)
|
||||
// Response wraps an http.ResponseWriter and implements its interface to be used
|
||||
// by an HTTP handler to construct an HTTP response.
|
||||
// See: https://golang.org/pkg/net/http/#ResponseWriter
|
||||
type Response struct {
|
||||
echo *Echo
|
||||
beforeFuncs []func()
|
||||
afterFuncs []func()
|
||||
Writer http.ResponseWriter
|
||||
Status int
|
||||
Size int64
|
||||
Committed bool
|
||||
}
|
||||
|
||||
// NewResponse creates a new instance of Response.
|
||||
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
|
||||
@ -47,13 +46,15 @@ func (r *Response) After(fn func()) {
|
||||
r.afterFuncs = append(r.afterFuncs, fn)
|
||||
}
|
||||
|
||||
var errHeaderAlreadyCommitted = errors.New("response already committed")
|
||||
|
||||
// WriteHeader sends an HTTP response header with status code. If WriteHeader is
|
||||
// not called explicitly, the first call to Write will trigger an implicit
|
||||
// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly
|
||||
// used to send error codes.
|
||||
func (r *Response) WriteHeader(code int) {
|
||||
if r.Committed {
|
||||
r.echo.Logger.Warn("response already committed")
|
||||
r.echo.Logger.Error(errHeaderAlreadyCommitted)
|
||||
return
|
||||
}
|
||||
r.Status = code
|
||||
|
182
route.go
Normal file
182
route.go
Normal file
@ -0,0 +1,182 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Route contains information to adding/registering new route with the router.
|
||||
// Method+Path pair uniquely identifies the Route. It is mandatory to provide Method+Path+Handler fields.
|
||||
type Route struct {
|
||||
Method string
|
||||
Path string
|
||||
Handler HandlerFunc
|
||||
Middlewares []MiddlewareFunc
|
||||
|
||||
Name string
|
||||
}
|
||||
|
||||
// ToRouteInfo converts Route to RouteInfo
|
||||
func (r Route) ToRouteInfo(params []string) RouteInfo {
|
||||
name := r.Name
|
||||
if name == "" {
|
||||
name = r.Method + ":" + r.Path
|
||||
}
|
||||
|
||||
return routeInfo{
|
||||
method: r.Method,
|
||||
path: r.Path,
|
||||
params: append([]string(nil), params...),
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// ToRoute returns Route which Router uses to register the method handler for path.
|
||||
func (r Route) ToRoute() Route {
|
||||
return r
|
||||
}
|
||||
|
||||
// ForGroup recreates Route with added group prefix and group middlewares it is grouped to.
|
||||
func (r Route) ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable {
|
||||
r.Path = pathPrefix + r.Path
|
||||
|
||||
if len(middlewares) > 0 {
|
||||
m := make([]MiddlewareFunc, 0, len(middlewares)+len(r.Middlewares))
|
||||
m = append(m, middlewares...)
|
||||
m = append(m, r.Middlewares...)
|
||||
r.Middlewares = m
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
type routeInfo struct {
|
||||
method string
|
||||
path string
|
||||
params []string
|
||||
name string
|
||||
}
|
||||
|
||||
func (r routeInfo) Method() string {
|
||||
return r.method
|
||||
}
|
||||
|
||||
func (r routeInfo) Path() string {
|
||||
return r.path
|
||||
}
|
||||
|
||||
func (r routeInfo) Params() []string {
|
||||
return append([]string(nil), r.params...)
|
||||
}
|
||||
|
||||
func (r routeInfo) Name() string {
|
||||
return r.name
|
||||
}
|
||||
|
||||
// Reverse reverses route to URL string by replacing path parameters with given params values.
|
||||
func (r routeInfo) Reverse(params ...interface{}) string {
|
||||
uri := new(bytes.Buffer)
|
||||
ln := len(params)
|
||||
n := 0
|
||||
for i, l := 0, len(r.path); i < l; i++ {
|
||||
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
|
||||
for ; i < l && r.path[i] != '/'; i++ {
|
||||
}
|
||||
uri.WriteString(fmt.Sprintf("%v", params[n]))
|
||||
n++
|
||||
}
|
||||
if i < l {
|
||||
uri.WriteByte(r.path[i])
|
||||
}
|
||||
}
|
||||
return uri.String()
|
||||
}
|
||||
|
||||
// HandlerName returns string name for given function.
|
||||
func HandlerName(h HandlerFunc) string {
|
||||
t := reflect.ValueOf(h).Type()
|
||||
if t.Kind() == reflect.Func {
|
||||
return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
|
||||
}
|
||||
return t.String()
|
||||
}
|
||||
|
||||
// Reverse reverses route to URL string by replacing path parameters with given params values.
|
||||
func (r Routes) Reverse(name string, params ...interface{}) (string, error) {
|
||||
for _, rr := range r {
|
||||
if rr.Name() == name {
|
||||
return rr.Reverse(params...), nil
|
||||
}
|
||||
}
|
||||
return "", errors.New("route not found")
|
||||
}
|
||||
|
||||
// FindByMethodPath searched for matching route info by method and path
|
||||
func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("route not found by method and path")
|
||||
}
|
||||
|
||||
for _, rr := range r {
|
||||
if rr.Method() == method && rr.Path() == path {
|
||||
return rr, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("route not found by method and path")
|
||||
}
|
||||
|
||||
// FilterByMethod searched for matching route info by method
|
||||
func (r Routes) FilterByMethod(method string) (Routes, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("route not found by method")
|
||||
}
|
||||
|
||||
result := make(Routes, 0)
|
||||
for _, rr := range r {
|
||||
if rr.Method() == method {
|
||||
result = append(result, rr)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errors.New("route not found by method")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FilterByPath searched for matching route info by path
|
||||
func (r Routes) FilterByPath(path string) (Routes, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("route not found by path")
|
||||
}
|
||||
|
||||
result := make(Routes, 0)
|
||||
for _, rr := range r {
|
||||
if rr.Path() == path {
|
||||
result = append(result, rr)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errors.New("route not found by path")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FilterByName searched for matching route info by name
|
||||
func (r Routes) FilterByName(name string) (Routes, error) {
|
||||
if r == nil {
|
||||
return nil, errors.New("route not found by name")
|
||||
}
|
||||
|
||||
result := make(Routes, 0)
|
||||
for _, rr := range r {
|
||||
if rr.Name() == name {
|
||||
result = append(result, rr)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errors.New("route not found by name")
|
||||
}
|
||||
return result, nil
|
||||
}
|
423
route_test.go
Normal file
423
route_test.go
Normal file
@ -0,0 +1,423 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var myNamedHandler = func(c Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type NameStruct struct {
|
||||
}
|
||||
|
||||
func (n *NameStruct) getUsers(c Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandlerName(t *testing.T) {
|
||||
myNameFuncVar := func(c Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tmp := NameStruct{}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenHandlerFunc HandlerFunc
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "ok, func as anonymous func",
|
||||
whenHandlerFunc: func(c Context) error {
|
||||
return nil
|
||||
},
|
||||
expect: "github.com/labstack/echo/v4.TestHandlerName.func2",
|
||||
},
|
||||
{
|
||||
name: "ok, func as named package variable",
|
||||
whenHandlerFunc: myNamedHandler,
|
||||
expect: "github.com/labstack/echo/v4.glob..func3",
|
||||
},
|
||||
{
|
||||
name: "ok, func as named function variable",
|
||||
whenHandlerFunc: myNameFuncVar,
|
||||
expect: "github.com/labstack/echo/v4.TestHandlerName.func1",
|
||||
},
|
||||
{
|
||||
name: "ok, func as struct method",
|
||||
whenHandlerFunc: tmp.getUsers,
|
||||
expect: "github.com/labstack/echo/v4.(*NameStruct).getUsers-fm",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
name := HandlerName(tc.whenHandlerFunc)
|
||||
assert.Equal(t, tc.expect, name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerName_differentFuncSameName(t *testing.T) {
|
||||
handlerCreator := func(name string) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
return c.String(http.StatusTeapot, name)
|
||||
}
|
||||
}
|
||||
h1 := handlerCreator("name1")
|
||||
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func2", HandlerName(h1))
|
||||
|
||||
h2 := handlerCreator("name2")
|
||||
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func3", HandlerName(h2))
|
||||
}
|
||||
|
||||
func TestRoute_ToRouteInfo(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given Route
|
||||
whenParams []string
|
||||
expect RouteInfo
|
||||
}{
|
||||
{
|
||||
name: "ok, no params, with name",
|
||||
given: Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/test",
|
||||
Handler: func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
},
|
||||
Middlewares: nil,
|
||||
Name: "test route",
|
||||
},
|
||||
expect: routeInfo{
|
||||
method: http.MethodGet,
|
||||
path: "/test",
|
||||
params: nil,
|
||||
name: "test route",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, params",
|
||||
given: Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "users/:id/:file", // no slash prefix
|
||||
Handler: func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
},
|
||||
Middlewares: nil,
|
||||
Name: "",
|
||||
},
|
||||
whenParams: []string{"id", "file"},
|
||||
expect: routeInfo{
|
||||
method: http.MethodGet,
|
||||
path: "users/:id/:file",
|
||||
params: []string{"id", "file"},
|
||||
name: "GET:users/:id/:file",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ri := tc.given.ToRouteInfo(tc.whenParams)
|
||||
assert.Equal(t, tc.expect, ri)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoute_ToRoute(t *testing.T) {
|
||||
route := Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/test",
|
||||
Handler: func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
},
|
||||
Middlewares: nil,
|
||||
Name: "test route",
|
||||
}
|
||||
|
||||
r := route.ToRoute()
|
||||
assert.Equal(t, r.Method, http.MethodGet)
|
||||
assert.Equal(t, r.Path, "/test")
|
||||
assert.NotNil(t, r.Handler)
|
||||
assert.Nil(t, r.Middlewares)
|
||||
assert.Equal(t, r.Name, "test route")
|
||||
}
|
||||
|
||||
func TestRoute_ForGroup(t *testing.T) {
|
||||
route := Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/test",
|
||||
Handler: func(c Context) error {
|
||||
return c.String(http.StatusTeapot, "OK")
|
||||
},
|
||||
Middlewares: nil,
|
||||
Name: "test route",
|
||||
}
|
||||
|
||||
mw := func(next HandlerFunc) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
gr := route.ForGroup("/users", []MiddlewareFunc{mw})
|
||||
|
||||
r := gr.ToRoute()
|
||||
assert.Equal(t, r.Method, http.MethodGet)
|
||||
assert.Equal(t, r.Path, "/users/test")
|
||||
assert.NotNil(t, r.Handler)
|
||||
assert.Len(t, r.Middlewares, 1)
|
||||
assert.Equal(t, r.Name, "test route")
|
||||
}
|
||||
|
||||
func exampleRoutes() Routes {
|
||||
return Routes{
|
||||
routeInfo{
|
||||
method: http.MethodGet,
|
||||
path: "/users",
|
||||
params: nil,
|
||||
name: "GET:/users",
|
||||
},
|
||||
routeInfo{
|
||||
method: http.MethodGet,
|
||||
path: "/users/:id",
|
||||
params: []string{"id"},
|
||||
name: "GET:/users/:id",
|
||||
},
|
||||
routeInfo{
|
||||
method: http.MethodPost,
|
||||
path: "/users/:id",
|
||||
params: []string{"id"},
|
||||
name: "POST:/users/:id",
|
||||
},
|
||||
routeInfo{
|
||||
method: http.MethodDelete,
|
||||
path: "/groups",
|
||||
params: nil,
|
||||
name: "non_unique_name",
|
||||
},
|
||||
routeInfo{
|
||||
method: http.MethodPost,
|
||||
path: "/groups",
|
||||
params: nil,
|
||||
name: "non_unique_name",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutes_FindByMethodPath(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given Routes
|
||||
whenMethod string
|
||||
whenPath string
|
||||
expectName string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, found",
|
||||
given: exampleRoutes(),
|
||||
whenMethod: http.MethodGet,
|
||||
whenPath: "/users/:id",
|
||||
expectName: "GET:/users/:id",
|
||||
},
|
||||
{
|
||||
name: "nok, not found",
|
||||
given: exampleRoutes(),
|
||||
whenMethod: http.MethodPut,
|
||||
whenPath: "/users/:id",
|
||||
expectName: "",
|
||||
expectError: "route not found by method and path",
|
||||
},
|
||||
{
|
||||
name: "nok, not found from nil",
|
||||
given: nil,
|
||||
whenMethod: http.MethodGet,
|
||||
whenPath: "/users/:id",
|
||||
expectName: "",
|
||||
expectError: "route not found by method and path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ri, err := tc.given.FindByMethodPath(tc.whenMethod, tc.whenPath)
|
||||
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
assert.Nil(t, ri)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tc.expectName != "" {
|
||||
assert.Equal(t, tc.expectName, ri.Name())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutes_FilterByMethod(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given Routes
|
||||
whenMethod string
|
||||
expectNames []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, found",
|
||||
given: exampleRoutes(),
|
||||
whenMethod: http.MethodGet,
|
||||
expectNames: []string{"GET:/users", "GET:/users/:id"},
|
||||
},
|
||||
{
|
||||
name: "nok, not found",
|
||||
given: exampleRoutes(),
|
||||
whenMethod: http.MethodPut,
|
||||
expectNames: nil,
|
||||
expectError: "route not found by method",
|
||||
},
|
||||
{
|
||||
name: "nok, not found from nil",
|
||||
given: nil,
|
||||
whenMethod: http.MethodGet,
|
||||
expectNames: nil,
|
||||
expectError: "route not found by method",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ris, err := tc.given.FilterByMethod(tc.whenMethod)
|
||||
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
if len(tc.expectNames) > 0 {
|
||||
assert.Len(t, ris, len(tc.expectNames))
|
||||
for _, ri := range ris {
|
||||
assert.Contains(t, tc.expectNames, ri.Name())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, ris)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutes_FilterByPath(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given Routes
|
||||
whenPath string
|
||||
expectNames []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, found",
|
||||
given: exampleRoutes(),
|
||||
whenPath: "/users/:id",
|
||||
expectNames: []string{"GET:/users/:id", "POST:/users/:id"},
|
||||
},
|
||||
{
|
||||
name: "nok, not found",
|
||||
given: exampleRoutes(),
|
||||
whenPath: "/",
|
||||
expectNames: nil,
|
||||
expectError: "route not found by path",
|
||||
},
|
||||
{
|
||||
name: "nok, not found from nil",
|
||||
given: nil,
|
||||
whenPath: "/users/:id",
|
||||
expectNames: nil,
|
||||
expectError: "route not found by path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ris, err := tc.given.FilterByPath(tc.whenPath)
|
||||
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
if len(tc.expectNames) > 0 {
|
||||
assert.Len(t, ris, len(tc.expectNames))
|
||||
for _, ri := range ris {
|
||||
assert.Contains(t, tc.expectNames, ri.Name())
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, ris)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutes_FilterByName(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given Routes
|
||||
whenName string
|
||||
expectMethodPath []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, found multiple",
|
||||
given: exampleRoutes(),
|
||||
whenName: "non_unique_name",
|
||||
expectMethodPath: []string{"DELETE:/groups", "POST:/groups"},
|
||||
},
|
||||
{
|
||||
name: "ok, found single",
|
||||
given: exampleRoutes(),
|
||||
whenName: "GET:/users/:id",
|
||||
expectMethodPath: []string{"GET:/users/:id"},
|
||||
},
|
||||
{
|
||||
name: "nok, not found",
|
||||
given: exampleRoutes(),
|
||||
whenName: "/",
|
||||
expectMethodPath: nil,
|
||||
expectError: "route not found by name",
|
||||
},
|
||||
{
|
||||
name: "nok, not found from nil",
|
||||
given: nil,
|
||||
whenName: "/users/:id",
|
||||
expectMethodPath: nil,
|
||||
expectError: "route not found by name",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ris, err := tc.given.FilterByName(tc.whenName)
|
||||
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
if len(tc.expectMethodPath) > 0 {
|
||||
assert.Len(t, ris, len(tc.expectMethodPath))
|
||||
for _, ri := range ris {
|
||||
assert.Contains(t, tc.expectMethodPath, fmt.Sprintf("%v:%v", ri.Method(), ri.Path()))
|
||||
}
|
||||
} else {
|
||||
assert.Nil(t, ris)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
783
router.go
783
router.go
@ -1,50 +1,134 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type (
|
||||
// Router is the registry of all registered routes for an `Echo` instance for
|
||||
// request matching and URL path parameter parsing.
|
||||
Router struct {
|
||||
tree *node
|
||||
routes map[string]*Route
|
||||
echo *Echo
|
||||
}
|
||||
node struct {
|
||||
kind kind
|
||||
label byte
|
||||
prefix string
|
||||
parent *node
|
||||
staticChildren children
|
||||
ppath string
|
||||
pnames []string
|
||||
methodHandler *methodHandler
|
||||
paramChild *node
|
||||
anyChild *node
|
||||
// isLeaf indicates that node does not have child routes
|
||||
isLeaf bool
|
||||
// isHandler indicates that node has at least one handler registered to it
|
||||
isHandler bool
|
||||
}
|
||||
kind uint8
|
||||
children []*node
|
||||
methodHandler struct {
|
||||
connect HandlerFunc
|
||||
delete HandlerFunc
|
||||
get HandlerFunc
|
||||
head HandlerFunc
|
||||
options HandlerFunc
|
||||
patch HandlerFunc
|
||||
post HandlerFunc
|
||||
propfind HandlerFunc
|
||||
put HandlerFunc
|
||||
trace HandlerFunc
|
||||
report HandlerFunc
|
||||
}
|
||||
// Router is interface for routing requests to registered routes.
|
||||
type Router interface {
|
||||
// Add registers Routable with the Router and returns registered RouteInfo
|
||||
Add(routable Routable) (RouteInfo, error)
|
||||
// Remove removes route from the Router
|
||||
Remove(method string, path string) error
|
||||
// Routes returns information about all registered routes
|
||||
Routes() Routes
|
||||
|
||||
// Match searches Router for matching route and applies it to result fields.
|
||||
Match(req *http.Request, params *PathParams) RouteMatch
|
||||
}
|
||||
|
||||
// Routable is interface for registering Route with Router. During route registration process the Router will
|
||||
// convert Routable to RouteInfo with ToRouteInfo method. By creating custom implementation of Routable additional
|
||||
// information about registered route can be stored in Routes (i.e. privileges used with route etc.)
|
||||
type Routable interface {
|
||||
// ToRouteInfo converts Routable to RouteInfo
|
||||
//
|
||||
// This method is meant to be used by Router after it parses url for path parameters, to store information about
|
||||
// route just added.
|
||||
ToRouteInfo(params []string) RouteInfo
|
||||
// ToRoute converts Routable to Route which Router uses to register the method handler for path.
|
||||
//
|
||||
// This method is meant to be used by Router to get fields (including handler and middleware functions) needed to
|
||||
// add Route to Router.
|
||||
ToRoute() Route
|
||||
// ForGroup recreates routable with added group prefix and group middlewares it is grouped to.
|
||||
//
|
||||
// Is necessary for Echo.Group to be able to add/register Routable with Router and having group prefix and group
|
||||
// middlewares included in actually registered Route.
|
||||
ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable
|
||||
}
|
||||
|
||||
// Routes is collection of RouteInfo instances with various helper methods.
|
||||
type Routes []RouteInfo
|
||||
|
||||
// RouteInfo describes registered route base fields.
|
||||
// Method+Path pair uniquely identifies the Route. Name can have duplicates.
|
||||
type RouteInfo interface {
|
||||
Method() string
|
||||
Path() string
|
||||
Name() string
|
||||
|
||||
Params() []string
|
||||
Reverse(params ...interface{}) string
|
||||
|
||||
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
|
||||
// it is not always 100% known if handler function already wraps middlewares or not. In Echo handler could be one
|
||||
// function or several functions wrapping each other.
|
||||
}
|
||||
|
||||
// RouteMatchType describes possible states that request could be in perspective of routing
|
||||
type RouteMatchType uint8
|
||||
|
||||
const (
|
||||
// RouteMatchUnknown is state before routing is done. Default state for fresh context.
|
||||
RouteMatchUnknown RouteMatchType = iota
|
||||
// RouteMatchNotFound is state when router did not find matching route for current request
|
||||
RouteMatchNotFound
|
||||
// RouteMatchMethodNotAllowed is state when router did not find route with matching path + method for current request.
|
||||
// Although router had matching route with that path but different method.
|
||||
RouteMatchMethodNotAllowed
|
||||
// RouteMatchFound is state when router found exact match for path + method combination
|
||||
RouteMatchFound
|
||||
)
|
||||
|
||||
// RouteMatch is result object for Router.Match. Its main purpose is to avoid allocating memory for PathParams inside router.
|
||||
type RouteMatch struct {
|
||||
// Type contains result as enumeration of Router.Match and helps to understand did Router actually matched Route or
|
||||
// what kind of error case (404/405) we have at the end of the handler chain.
|
||||
Type RouteMatchType
|
||||
// RoutePath contains original path with what matched route was registered with (including placeholders etc).
|
||||
RoutePath string
|
||||
// Handler is function(chain) that was matched by router. In case of no match could result to ErrNotFound or ErrMethodNotAllowed.
|
||||
Handler HandlerFunc
|
||||
// RouteInfo is information about route we just matched
|
||||
RouteInfo RouteInfo
|
||||
}
|
||||
|
||||
// PathParams is collections of PathParam instances with various helper methods
|
||||
type PathParams []PathParam
|
||||
|
||||
// PathParam is tuple pf path parameter name and its value in request path
|
||||
type PathParam struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// DefaultRouter is the registry of all registered routes for an `Echo` instance for
|
||||
// request matching and URL path parameter parsing.
|
||||
// Note: DefaultRouter is not coroutine-safe. Do not Add/Remove routes after HTTP server has been started with Echo.
|
||||
type DefaultRouter struct {
|
||||
tree *node
|
||||
routes Routes
|
||||
echo *Echo
|
||||
|
||||
allowOverwritingRoute bool
|
||||
unescapePathParamValues bool
|
||||
useEscapedPathForRouting bool
|
||||
}
|
||||
|
||||
type children []*node
|
||||
|
||||
type node struct {
|
||||
kind kind
|
||||
label byte
|
||||
prefix string
|
||||
parent *node
|
||||
staticChildren children
|
||||
originalPath string
|
||||
methods *routeMethods
|
||||
paramChild *node
|
||||
anyChild *node
|
||||
paramsCount int
|
||||
// isLeaf indicates that node does not have child routes
|
||||
isLeaf bool
|
||||
// isHandler indicates that node has at least one handler registered to it
|
||||
isHandler bool
|
||||
}
|
||||
|
||||
type kind uint8
|
||||
|
||||
const (
|
||||
staticKind kind = iota
|
||||
paramKind
|
||||
@ -54,90 +138,362 @@ const (
|
||||
anyLabel = byte('*')
|
||||
)
|
||||
|
||||
func (m *methodHandler) isHandler() bool {
|
||||
return m.connect != nil ||
|
||||
m.delete != nil ||
|
||||
m.get != nil ||
|
||||
m.head != nil ||
|
||||
m.options != nil ||
|
||||
m.patch != nil ||
|
||||
m.post != nil ||
|
||||
m.propfind != nil ||
|
||||
m.put != nil ||
|
||||
m.trace != nil ||
|
||||
m.report != nil
|
||||
type routeMethod struct {
|
||||
*routeInfo
|
||||
handler HandlerFunc
|
||||
orgRouteInfo RouteInfo
|
||||
}
|
||||
|
||||
// NewRouter returns a new Router instance.
|
||||
func NewRouter(e *Echo) *Router {
|
||||
return &Router{
|
||||
tree: &node{
|
||||
methodHandler: new(methodHandler),
|
||||
},
|
||||
routes: map[string]*Route{},
|
||||
echo: e,
|
||||
type routeMethods struct {
|
||||
connect *routeMethod
|
||||
delete *routeMethod
|
||||
get *routeMethod
|
||||
head *routeMethod
|
||||
options *routeMethod
|
||||
patch *routeMethod
|
||||
post *routeMethod
|
||||
propfind *routeMethod
|
||||
put *routeMethod
|
||||
trace *routeMethod
|
||||
report *routeMethod
|
||||
anyOther map[string]*routeMethod
|
||||
}
|
||||
|
||||
func (m *routeMethods) set(method string, r *routeMethod) {
|
||||
switch method {
|
||||
case http.MethodConnect:
|
||||
m.connect = r
|
||||
case http.MethodDelete:
|
||||
m.delete = r
|
||||
case http.MethodGet:
|
||||
m.get = r
|
||||
case http.MethodHead:
|
||||
m.head = r
|
||||
case http.MethodOptions:
|
||||
m.options = r
|
||||
case http.MethodPatch:
|
||||
m.patch = r
|
||||
case http.MethodPost:
|
||||
m.post = r
|
||||
case PROPFIND:
|
||||
m.propfind = r
|
||||
case http.MethodPut:
|
||||
m.put = r
|
||||
case http.MethodTrace:
|
||||
m.trace = r
|
||||
case REPORT:
|
||||
m.report = r
|
||||
default:
|
||||
if m.anyOther == nil {
|
||||
m.anyOther = make(map[string]*routeMethod)
|
||||
}
|
||||
if r.handler == nil {
|
||||
delete(m.anyOther, method)
|
||||
} else {
|
||||
m.anyOther[method] = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add registers a new route for method and path with matching handler.
|
||||
func (r *Router) Add(method, path string, h HandlerFunc) {
|
||||
// Validate path
|
||||
func (m *routeMethods) find(method string) *routeMethod {
|
||||
switch method {
|
||||
case http.MethodConnect:
|
||||
return m.connect
|
||||
case http.MethodDelete:
|
||||
return m.delete
|
||||
case http.MethodGet:
|
||||
return m.get
|
||||
case http.MethodHead:
|
||||
return m.head
|
||||
case http.MethodOptions:
|
||||
return m.options
|
||||
case http.MethodPatch:
|
||||
return m.patch
|
||||
case http.MethodPost:
|
||||
return m.post
|
||||
case PROPFIND:
|
||||
return m.propfind
|
||||
case http.MethodPut:
|
||||
return m.put
|
||||
case http.MethodTrace:
|
||||
return m.trace
|
||||
case REPORT:
|
||||
return m.report
|
||||
default:
|
||||
return m.anyOther[method]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *routeMethods) isHandler() bool {
|
||||
return m.get != nil ||
|
||||
m.post != nil ||
|
||||
m.options != nil ||
|
||||
m.put != nil ||
|
||||
m.delete != nil ||
|
||||
m.connect != nil ||
|
||||
m.head != nil ||
|
||||
m.patch != nil ||
|
||||
m.propfind != nil ||
|
||||
m.trace != nil ||
|
||||
m.report != nil ||
|
||||
len(m.anyOther) != 0
|
||||
}
|
||||
|
||||
// RouterConfig is configuration options for (default) router
|
||||
type RouterConfig struct {
|
||||
// AllowOverwritingRoute instructs Router NOT to return error when new route is registered with the same method+path
|
||||
// and replaces matching route with the new one.
|
||||
AllowOverwritingRoute bool
|
||||
// UnescapePathParamValues instructs Router to unescape path parameter value when request if matched to the routes
|
||||
UnescapePathParamValues bool
|
||||
// UseEscapedPathForMatching instructs Router to use escaped request URL path (req.URL.Path) for matching the request.
|
||||
UseEscapedPathForMatching bool
|
||||
}
|
||||
|
||||
// NewRouter returns a new Router instance.
|
||||
func NewRouter(e *Echo, config RouterConfig) *DefaultRouter {
|
||||
r := &DefaultRouter{
|
||||
tree: &node{
|
||||
methods: new(routeMethods),
|
||||
isLeaf: true,
|
||||
isHandler: false,
|
||||
},
|
||||
routes: make(Routes, 0),
|
||||
echo: e,
|
||||
|
||||
allowOverwritingRoute: config.AllowOverwritingRoute,
|
||||
unescapePathParamValues: config.UnescapePathParamValues,
|
||||
useEscapedPathForRouting: config.UseEscapedPathForMatching,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Routes returns all registered routes
|
||||
func (r *DefaultRouter) Routes() Routes {
|
||||
return r.routes
|
||||
}
|
||||
|
||||
// Remove unregisters registered route
|
||||
func (r *DefaultRouter) Remove(method string, path string) error {
|
||||
currentNode := r.tree
|
||||
if currentNode == nil || (currentNode.isLeaf && !currentNode.isHandler) {
|
||||
return errors.New("router has no routes to remove")
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
if path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
pnames := []string{} // Param names
|
||||
ppath := path // Pristine path
|
||||
|
||||
if h == nil && r.echo.Logger != nil {
|
||||
// FIXME: in future we should return error
|
||||
r.echo.Logger.Errorf("Adding route without handler function: %v:%v", method, path)
|
||||
var nodeToRemove *node
|
||||
prefixLen := 0
|
||||
for {
|
||||
if currentNode.originalPath == path && currentNode.isHandler {
|
||||
nodeToRemove = currentNode
|
||||
break
|
||||
}
|
||||
if currentNode.kind == staticKind {
|
||||
prefixLen = prefixLen + len(currentNode.prefix)
|
||||
} else {
|
||||
prefixLen = len(currentNode.originalPath)
|
||||
}
|
||||
|
||||
if prefixLen >= len(path) {
|
||||
break
|
||||
}
|
||||
|
||||
next := path[prefixLen]
|
||||
switch next {
|
||||
case paramLabel:
|
||||
currentNode = currentNode.paramChild
|
||||
case anyLabel:
|
||||
currentNode = currentNode.anyChild
|
||||
default:
|
||||
currentNode = currentNode.findStaticChild(next)
|
||||
}
|
||||
|
||||
if currentNode == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nodeToRemove == nil {
|
||||
return errors.New("could not find route to remove by given path")
|
||||
}
|
||||
|
||||
if !nodeToRemove.isHandler {
|
||||
return errors.New("could not find route to remove by given path")
|
||||
}
|
||||
|
||||
if mh := nodeToRemove.methods.find(method); mh == nil {
|
||||
return errors.New("could not find route to remove by given path and method")
|
||||
}
|
||||
nodeToRemove.setHandler(method, nil)
|
||||
|
||||
var rIndex int
|
||||
for i, rr := range r.routes {
|
||||
if rr.Method() == method && rr.Path() == path {
|
||||
rIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
r.routes = append(r.routes[:rIndex], r.routes[rIndex+1:]...)
|
||||
|
||||
if !nodeToRemove.isHandler && nodeToRemove.isLeaf {
|
||||
// TODO: if !nodeToRemove.isLeaf and has at least 2 children merge paths for remaining nodes?
|
||||
current := nodeToRemove
|
||||
for {
|
||||
parent := current.parent
|
||||
if parent == nil {
|
||||
break
|
||||
}
|
||||
switch current.kind {
|
||||
case staticKind:
|
||||
var index int
|
||||
for i, c := range parent.staticChildren {
|
||||
if c == current {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
parent.staticChildren = append(parent.staticChildren[:index], parent.staticChildren[index+1:]...)
|
||||
case paramKind:
|
||||
parent.paramChild = nil
|
||||
case anyKind:
|
||||
parent.anyChild = nil
|
||||
}
|
||||
|
||||
parent.isLeaf = parent.anyChild == nil && parent.paramChild == nil && len(parent.staticChildren) == 0
|
||||
if !parent.isLeaf || parent.isHandler {
|
||||
break
|
||||
}
|
||||
current = parent
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRouteError is error returned by Router.Add containing information what actual route adding failed. Useful for
|
||||
// mass adding (i.e. Any() routes)
|
||||
type AddRouteError struct {
|
||||
Method string
|
||||
Path string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AddRouteError) Error() string { return e.Method + " " + e.Path + ": " + e.Err.Error() }
|
||||
|
||||
func (e *AddRouteError) Unwrap() error { return e.Err }
|
||||
|
||||
func newAddRouteError(route Route, err error) *AddRouteError {
|
||||
return &AddRouteError{
|
||||
Method: route.Method,
|
||||
Path: route.Path,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Add registers a new route for method and path with matching handler.
|
||||
func (r *DefaultRouter) Add(routable Routable) (RouteInfo, error) {
|
||||
route := routable.ToRoute()
|
||||
if route.Handler == nil {
|
||||
return nil, newAddRouteError(route, errors.New("adding route without handler function"))
|
||||
}
|
||||
method := route.Method
|
||||
path := route.Path
|
||||
h := applyMiddleware(route.Handler, route.Middlewares...)
|
||||
if !r.allowOverwritingRoute {
|
||||
for _, rr := range r.routes {
|
||||
if route.Method == rr.Method() && route.Path == rr.Path() {
|
||||
return nil, newAddRouteError(route, errors.New("adding duplicate route (same method+path) is not allowed"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
if path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
paramNames := make([]string, 0)
|
||||
originalPath := path
|
||||
wasAdded := false
|
||||
var ri RouteInfo
|
||||
for i, lcpIndex := 0, len(path); i < lcpIndex; i++ {
|
||||
if path[i] == ':' {
|
||||
if path[i] == paramLabel {
|
||||
if i > 0 && path[i-1] == '\\' {
|
||||
continue
|
||||
}
|
||||
j := i + 1
|
||||
|
||||
r.insert(method, path[:i], nil, staticKind, "", nil)
|
||||
r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
|
||||
for ; i < lcpIndex && path[i] != '/'; i++ {
|
||||
}
|
||||
|
||||
pnames = append(pnames, path[j:i])
|
||||
paramNames = append(paramNames, path[j:i])
|
||||
path = path[:j] + path[i:]
|
||||
i, lcpIndex = j, len(path)
|
||||
|
||||
if i == lcpIndex {
|
||||
// path node is last fragment of route path. ie. `/users/:id`
|
||||
r.insert(method, path[:i], h, paramKind, ppath, pnames)
|
||||
ri = routable.ToRouteInfo(paramNames)
|
||||
rm := routeMethod{
|
||||
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
|
||||
handler: h,
|
||||
orgRouteInfo: ri,
|
||||
}
|
||||
r.insert(paramKind, path[:i], method, rm)
|
||||
wasAdded = true
|
||||
break
|
||||
} else {
|
||||
r.insert(method, path[:i], nil, paramKind, "", nil)
|
||||
r.insert(paramKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
|
||||
}
|
||||
} else if path[i] == '*' {
|
||||
r.insert(method, path[:i], nil, staticKind, "", nil)
|
||||
pnames = append(pnames, "*")
|
||||
r.insert(method, path[:i+1], h, anyKind, ppath, pnames)
|
||||
} else if path[i] == anyLabel {
|
||||
r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
|
||||
paramNames = append(paramNames, "*")
|
||||
ri = routable.ToRouteInfo(paramNames)
|
||||
rm := routeMethod{
|
||||
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
|
||||
handler: h,
|
||||
orgRouteInfo: ri,
|
||||
}
|
||||
r.insert(anyKind, path[:i+1], method, rm)
|
||||
wasAdded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
r.insert(method, path, h, staticKind, ppath, pnames)
|
||||
if !wasAdded {
|
||||
ri = routable.ToRouteInfo(paramNames)
|
||||
rm := routeMethod{
|
||||
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
|
||||
handler: h,
|
||||
orgRouteInfo: ri,
|
||||
}
|
||||
r.insert(staticKind, path, method, rm)
|
||||
}
|
||||
|
||||
r.storeRouteInfo(ri)
|
||||
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) {
|
||||
// Adjust max param
|
||||
paramLen := len(pnames)
|
||||
if *r.echo.maxParam < paramLen {
|
||||
*r.echo.maxParam = paramLen
|
||||
func (r *DefaultRouter) storeRouteInfo(ri RouteInfo) {
|
||||
for i, rr := range r.routes {
|
||||
if ri.Method() == rr.Method() && ri.Path() == rr.Path() {
|
||||
r.routes[i] = ri
|
||||
return
|
||||
}
|
||||
}
|
||||
r.routes = append(r.routes, ri)
|
||||
}
|
||||
|
||||
func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMethod) {
|
||||
currentNode := r.tree // Current node as root
|
||||
if currentNode == nil {
|
||||
panic("echo: invalid method")
|
||||
}
|
||||
search := path
|
||||
|
||||
for {
|
||||
@ -157,11 +513,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
|
||||
// At root node
|
||||
currentNode.label = search[0]
|
||||
currentNode.prefix = search
|
||||
if h != nil {
|
||||
if ri.handler != nil {
|
||||
currentNode.kind = t
|
||||
currentNode.addHandler(method, h)
|
||||
currentNode.ppath = ppath
|
||||
currentNode.pnames = pnames
|
||||
currentNode.setHandler(method, &ri)
|
||||
currentNode.paramsCount = len(ri.params)
|
||||
currentNode.originalPath = ri.path
|
||||
}
|
||||
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
|
||||
} else if lcpLen < prefixLen {
|
||||
@ -171,9 +527,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
|
||||
currentNode.prefix[lcpLen:],
|
||||
currentNode,
|
||||
currentNode.staticChildren,
|
||||
currentNode.methodHandler,
|
||||
currentNode.ppath,
|
||||
currentNode.pnames,
|
||||
currentNode.methods,
|
||||
currentNode.paramsCount,
|
||||
currentNode.originalPath,
|
||||
currentNode.paramChild,
|
||||
currentNode.anyChild,
|
||||
)
|
||||
@ -193,9 +549,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
|
||||
currentNode.label = currentNode.prefix[0]
|
||||
currentNode.prefix = currentNode.prefix[:lcpLen]
|
||||
currentNode.staticChildren = nil
|
||||
currentNode.methodHandler = new(methodHandler)
|
||||
currentNode.ppath = ""
|
||||
currentNode.pnames = nil
|
||||
currentNode.methods = new(routeMethods)
|
||||
currentNode.originalPath = ""
|
||||
currentNode.paramsCount = 0
|
||||
currentNode.paramChild = nil
|
||||
currentNode.anyChild = nil
|
||||
currentNode.isLeaf = false
|
||||
@ -207,13 +563,18 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
|
||||
if lcpLen == searchLen {
|
||||
// At parent node
|
||||
currentNode.kind = t
|
||||
currentNode.addHandler(method, h)
|
||||
currentNode.ppath = ppath
|
||||
currentNode.pnames = pnames
|
||||
if ri.handler != nil {
|
||||
currentNode.setHandler(method, &ri)
|
||||
currentNode.paramsCount = len(ri.params)
|
||||
currentNode.originalPath = ri.path
|
||||
}
|
||||
} else {
|
||||
// Create child node
|
||||
n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
|
||||
n.addHandler(method, h)
|
||||
n = newNode(t, search[lcpLen:], currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
|
||||
if ri.handler != nil {
|
||||
n.setHandler(method, &ri)
|
||||
n.paramsCount = len(ri.params)
|
||||
}
|
||||
// Only Static children could reach here
|
||||
currentNode.addStaticChild(n)
|
||||
}
|
||||
@ -227,8 +588,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
|
||||
continue
|
||||
}
|
||||
// Create child node
|
||||
n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
|
||||
n.addHandler(method, h)
|
||||
n := newNode(t, search, currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
|
||||
if ri.handler != nil {
|
||||
n.setHandler(method, &ri)
|
||||
n.paramsCount = len(ri.params)
|
||||
}
|
||||
switch t {
|
||||
case staticKind:
|
||||
currentNode.addStaticChild(n)
|
||||
@ -240,28 +604,26 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
|
||||
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
|
||||
} else {
|
||||
// Node already exists
|
||||
if h != nil {
|
||||
currentNode.addHandler(method, h)
|
||||
currentNode.ppath = ppath
|
||||
if len(currentNode.pnames) == 0 { // Issue #729
|
||||
currentNode.pnames = pnames
|
||||
}
|
||||
if ri.handler != nil {
|
||||
currentNode.setHandler(method, &ri)
|
||||
currentNode.paramsCount = len(ri.params)
|
||||
currentNode.originalPath = ri.path
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node {
|
||||
func newNode(t kind, pre string, p *node, sc children, mh *routeMethods, paramsCount int, ppath string, paramChildren, anyChildren *node) *node {
|
||||
return &node{
|
||||
kind: t,
|
||||
label: pre[0],
|
||||
prefix: pre,
|
||||
parent: p,
|
||||
staticChildren: sc,
|
||||
ppath: ppath,
|
||||
pnames: pnames,
|
||||
methodHandler: mh,
|
||||
originalPath: ppath,
|
||||
paramsCount: paramsCount,
|
||||
methods: mh,
|
||||
paramChild: paramChildren,
|
||||
anyChild: anyChildren,
|
||||
isLeaf: sc == nil && paramChildren == nil && anyChildren == nil,
|
||||
@ -297,99 +659,77 @@ func (n *node) findChildWithLabel(l byte) *node {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *node) addHandler(method string, h HandlerFunc) {
|
||||
switch method {
|
||||
case http.MethodConnect:
|
||||
n.methodHandler.connect = h
|
||||
case http.MethodDelete:
|
||||
n.methodHandler.delete = h
|
||||
case http.MethodGet:
|
||||
n.methodHandler.get = h
|
||||
case http.MethodHead:
|
||||
n.methodHandler.head = h
|
||||
case http.MethodOptions:
|
||||
n.methodHandler.options = h
|
||||
case http.MethodPatch:
|
||||
n.methodHandler.patch = h
|
||||
case http.MethodPost:
|
||||
n.methodHandler.post = h
|
||||
case PROPFIND:
|
||||
n.methodHandler.propfind = h
|
||||
case http.MethodPut:
|
||||
n.methodHandler.put = h
|
||||
case http.MethodTrace:
|
||||
n.methodHandler.trace = h
|
||||
case REPORT:
|
||||
n.methodHandler.report = h
|
||||
}
|
||||
|
||||
if h != nil {
|
||||
func (n *node) setHandler(method string, r *routeMethod) {
|
||||
n.methods.set(method, r)
|
||||
if r != nil && r.handler != nil {
|
||||
n.isHandler = true
|
||||
} else {
|
||||
n.isHandler = n.methodHandler.isHandler()
|
||||
n.isHandler = n.methods.isHandler()
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) findHandler(method string) HandlerFunc {
|
||||
switch method {
|
||||
case http.MethodConnect:
|
||||
return n.methodHandler.connect
|
||||
case http.MethodDelete:
|
||||
return n.methodHandler.delete
|
||||
case http.MethodGet:
|
||||
return n.methodHandler.get
|
||||
case http.MethodHead:
|
||||
return n.methodHandler.head
|
||||
case http.MethodOptions:
|
||||
return n.methodHandler.options
|
||||
case http.MethodPatch:
|
||||
return n.methodHandler.patch
|
||||
case http.MethodPost:
|
||||
return n.methodHandler.post
|
||||
case PROPFIND:
|
||||
return n.methodHandler.propfind
|
||||
case http.MethodPut:
|
||||
return n.methodHandler.put
|
||||
case http.MethodTrace:
|
||||
return n.methodHandler.trace
|
||||
case REPORT:
|
||||
return n.methodHandler.report
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
const (
|
||||
// NotFoundRouteName is name of RouteInfo returned when router did not find matching route (404: not found).
|
||||
NotFoundRouteName = "EchoRouteNotFound"
|
||||
// MethodNotAllowedRouteName is name of RouteInfo returned when router did not find matching method for route (404: method not allowed).
|
||||
MethodNotAllowedRouteName = "EchoRouteMethodNotAllowed"
|
||||
)
|
||||
|
||||
// Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to RouteMatch
|
||||
var notFoundRouteInfo = &routeInfo{
|
||||
method: "",
|
||||
path: "",
|
||||
params: nil,
|
||||
name: NotFoundRouteName,
|
||||
}
|
||||
|
||||
func (n *node) checkMethodNotAllowed() HandlerFunc {
|
||||
for _, m := range methods {
|
||||
if h := n.findHandler(m); h != nil {
|
||||
return MethodNotAllowedHandler
|
||||
}
|
||||
}
|
||||
return NotFoundHandler
|
||||
// Note: methodNotAllowedRouteInfo exists to avoid allocations when setting 405 RouteInfo to RouteMatch
|
||||
var methodNotAllowedRouteInfo = &routeInfo{
|
||||
method: "",
|
||||
path: "",
|
||||
params: nil,
|
||||
name: MethodNotAllowedRouteName,
|
||||
}
|
||||
|
||||
// Find lookup a handler registered for method and path. It also parses URL for path
|
||||
// parameters and load them into context.
|
||||
// notFoundHandler is handler for 404 cases
|
||||
// Handle returned ErrNotFound errors in Echo.HTTPErrorHandler
|
||||
var notFoundHandler = func(c Context) error {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
// methodNotAllowedHandler is handler for case when route for path+method match was not found (http code 405)
|
||||
// Handle returned ErrMethodNotAllowed errors in Echo.HTTPErrorHandler
|
||||
var methodNotAllowedHandler = func(c Context) error {
|
||||
return ErrMethodNotAllowed
|
||||
}
|
||||
|
||||
// Match looks up a handler registered for method and path. It also parses URL for path parameters and loads them
|
||||
// into context.
|
||||
//
|
||||
// For performance:
|
||||
//
|
||||
// - Get context from `Echo#AcquireContext()`
|
||||
// - Reset it `Context#Reset()`
|
||||
// - Return it `Echo#ReleaseContext()`.
|
||||
func (r *Router) Find(method, path string, c Context) {
|
||||
ctx := c.(*context)
|
||||
ctx.path = path
|
||||
currentNode := r.tree // Current node as root
|
||||
func (r *DefaultRouter) Match(req *http.Request, pathParams *PathParams) RouteMatch {
|
||||
*pathParams = (*pathParams)[0:cap(*pathParams)]
|
||||
|
||||
path := req.URL.Path
|
||||
if !r.useEscapedPathForRouting && req.URL.RawPath != "" {
|
||||
// Difference between URL.RawPath and URL.Path is:
|
||||
// * URL.Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/.
|
||||
// * URL.RawPath is an optional field which only gets set if the default encoding is different from Path.
|
||||
path = req.URL.RawPath
|
||||
}
|
||||
var (
|
||||
currentNode = r.tree // root as current node
|
||||
previousBestMatchNode *node
|
||||
matchedHandler HandlerFunc
|
||||
matchedRouteMethod *routeMethod
|
||||
// search stores the remaining path to check for match. By each iteration we move from start of path to end of the path
|
||||
// and search value gets shorter and shorter.
|
||||
search = path
|
||||
searchIndex = 0
|
||||
paramIndex int // Param counter
|
||||
paramValues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice
|
||||
paramIndex int // Param counter
|
||||
)
|
||||
|
||||
// Backtracking is needed when a dead end (leaf node) is reached in the router tree.
|
||||
@ -421,8 +761,8 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
paramIndex--
|
||||
// for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue
|
||||
// for that index as it would also contain part of path we cut off before moving into node we are backtracking from
|
||||
searchIndex -= len(paramValues[paramIndex])
|
||||
paramValues[paramIndex] = ""
|
||||
searchIndex -= len((*pathParams)[paramIndex].Value)
|
||||
(*pathParams)[paramIndex].Value = ""
|
||||
}
|
||||
search = path[searchIndex:]
|
||||
return
|
||||
@ -456,7 +796,7 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
// No matching prefix, let's backtrack to the first possible alternative node of the decision path
|
||||
nk, ok := backtrackToNextNodeKind(staticKind)
|
||||
if !ok {
|
||||
return // No other possibilities on the decision path
|
||||
break // No other possibilities on the decision path
|
||||
} else if nk == paramKind {
|
||||
goto Param
|
||||
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
|
||||
@ -479,8 +819,8 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
if previousBestMatchNode == nil {
|
||||
previousBestMatchNode = currentNode
|
||||
}
|
||||
if h := currentNode.findHandler(method); h != nil {
|
||||
matchedHandler = h
|
||||
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
|
||||
matchedRouteMethod = rMethod
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -507,7 +847,7 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
}
|
||||
}
|
||||
|
||||
paramValues[paramIndex] = search[:i]
|
||||
(*pathParams)[paramIndex].Value = search[:i]
|
||||
paramIndex++
|
||||
search = search[i:]
|
||||
searchIndex = searchIndex + i
|
||||
@ -519,7 +859,7 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
if child := currentNode.anyChild; child != nil {
|
||||
// If any node is found, use remaining path for paramValues
|
||||
currentNode = child
|
||||
paramValues[len(currentNode.pnames)-1] = search
|
||||
(*pathParams)[currentNode.paramsCount-1].Value = search
|
||||
// update indexes/search in case we need to backtrack when no handler match is found
|
||||
paramIndex++
|
||||
searchIndex += +len(search)
|
||||
@ -530,8 +870,8 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
if previousBestMatchNode == nil {
|
||||
previousBestMatchNode = currentNode
|
||||
}
|
||||
if h := currentNode.findHandler(method); h != nil {
|
||||
matchedHandler = h
|
||||
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
|
||||
matchedRouteMethod = rMethod
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -550,20 +890,63 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
}
|
||||
}
|
||||
|
||||
result := RouteMatch{
|
||||
Type: RouteMatchNotFound,
|
||||
Handler: notFoundHandler,
|
||||
RoutePath: "",
|
||||
RouteInfo: notFoundRouteInfo,
|
||||
}
|
||||
if currentNode == nil && previousBestMatchNode == nil {
|
||||
return // nothing matched at all
|
||||
*pathParams = (*pathParams)[0:0]
|
||||
return result // nothing matched at all with given path
|
||||
}
|
||||
|
||||
if matchedHandler != nil {
|
||||
ctx.handler = matchedHandler
|
||||
if matchedRouteMethod != nil {
|
||||
result.Type = RouteMatchFound
|
||||
result.Handler = matchedRouteMethod.handler
|
||||
result.RoutePath = matchedRouteMethod.routeInfo.path
|
||||
result.RouteInfo = matchedRouteMethod.routeInfo
|
||||
} else {
|
||||
// use previous match as basis. although we have no matching handler we have path match.
|
||||
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
|
||||
currentNode = previousBestMatchNode
|
||||
ctx.handler = currentNode.checkMethodNotAllowed()
|
||||
}
|
||||
ctx.path = currentNode.ppath
|
||||
ctx.pnames = currentNode.pnames
|
||||
|
||||
return
|
||||
// this here is only reason why `RouteMatch.RoutePath` exists. We do not want to create new RouteInfo just for path.
|
||||
result.RoutePath = currentNode.originalPath
|
||||
if currentNode.isHandler {
|
||||
// TODO: in case of OPTIONS method we could respond with list of methods that node has. See https://httpwg.org/specs/rfc7231.html#OPTIONS
|
||||
result.Type = RouteMatchMethodNotAllowed
|
||||
result.Handler = methodNotAllowedHandler
|
||||
result.RouteInfo = methodNotAllowedRouteInfo
|
||||
}
|
||||
}
|
||||
|
||||
*pathParams = (*pathParams)[0:currentNode.paramsCount]
|
||||
if matchedRouteMethod != nil {
|
||||
for i, name := range matchedRouteMethod.params {
|
||||
(*pathParams)[i].Name = name
|
||||
}
|
||||
}
|
||||
|
||||
if r.unescapePathParamValues && currentNode.kind != staticKind {
|
||||
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
|
||||
for i, p := range *pathParams {
|
||||
tmpVal, err := url.PathUnescape(p.Value)
|
||||
if err == nil { // handle problems by ignoring them.
|
||||
(*pathParams)[i].Value = tmpVal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Get returns path parameter value for given name or default value.
|
||||
func (p PathParams) Get(name string, defaultValue string) string {
|
||||
for _, param := range p {
|
||||
if param.Name == name {
|
||||
return param.Value
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
1421
router_test.go
1421
router_test.go
File diff suppressed because it is too large
Load Diff
220
server.go
Normal file
220
server.go
Normal file
@ -0,0 +1,220 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
stdContext "context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
website = "https://echo.labstack.com"
|
||||
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
|
||||
banner = `
|
||||
____ __
|
||||
/ __/___/ / ___
|
||||
/ _// __/ _ \/ _ \
|
||||
/___/\__/_//_/\___/ %s
|
||||
High performance, minimalist Go web framework
|
||||
%s
|
||||
____________________________________O/_______
|
||||
O\
|
||||
`
|
||||
)
|
||||
|
||||
// StartConfig is for creating configured http.Server instance to start serve http(s) requests with given Echo instance
|
||||
type StartConfig struct {
|
||||
// Address for the server to listen on (if not using custom listener)
|
||||
Address string
|
||||
|
||||
// ListenerNetwork allows setting listener network (see net.Listen for allowed values)
|
||||
// Optional: defaults to "tcp"
|
||||
ListenerNetwork string
|
||||
|
||||
// CertFilesystem is file system used to load certificates and keys (if certs/keys are given as paths)
|
||||
CertFilesystem fs.FS
|
||||
|
||||
// DisableHTTP2 disables supports for HTTP2 in TLS server
|
||||
DisableHTTP2 bool
|
||||
|
||||
// HideBanner does not log Echo banner on server startup
|
||||
HideBanner bool
|
||||
|
||||
// HidePort does not log port on server startup
|
||||
HidePort bool
|
||||
|
||||
// GracefulContext is context that completion signals graceful shutdown start
|
||||
GracefulContext stdContext.Context
|
||||
|
||||
// GracefulTimeout is period which server allows listeners to finish serving ongoing requests. If this time is exceeded process is exited
|
||||
// Defaults to 10 seconds
|
||||
GracefulTimeout time.Duration
|
||||
|
||||
// OnShutdownError allows customization of what happens when (graceful) server Shutdown method returns an error.
|
||||
// Defaults to calling e.logger.Error(err)
|
||||
OnShutdownError func(err error)
|
||||
|
||||
// TLSConfigFunc allows modifying TLS configuration before listener is created with it.
|
||||
TLSConfigFunc func(tlsConfig *tls.Config)
|
||||
|
||||
// ListenerAddrFunc allows getting listener address before server starts serving requests on listener. Useful when
|
||||
// address is set as random (`:0`) port.
|
||||
ListenerAddrFunc func(addr net.Addr)
|
||||
|
||||
// BeforeServeFunc allows customizing/accessing server before server starts serving requests on listener.
|
||||
BeforeServeFunc func(s *http.Server) error
|
||||
}
|
||||
|
||||
// Start starts a HTTP server.
|
||||
func (sc StartConfig) Start(e *Echo) error {
|
||||
logger := e.Logger
|
||||
server := http.Server{
|
||||
Handler: e,
|
||||
ErrorLog: log.New(logger, "", 0),
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config = nil
|
||||
if sc.TLSConfigFunc != nil {
|
||||
tlsConfig = &tls.Config{}
|
||||
configureTLS(&sc, tlsConfig)
|
||||
sc.TLSConfigFunc(tlsConfig)
|
||||
}
|
||||
|
||||
listener, err := createListener(&sc, tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return serve(&sc, &server, listener, logger)
|
||||
}
|
||||
|
||||
// StartTLS starts a HTTPS server.
|
||||
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
|
||||
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
|
||||
func (sc StartConfig) StartTLS(e *Echo, certFile, keyFile interface{}) error {
|
||||
logger := e.Logger
|
||||
s := http.Server{
|
||||
Handler: e,
|
||||
ErrorLog: log.New(logger, "", 0),
|
||||
}
|
||||
|
||||
certFs := sc.CertFilesystem
|
||||
if certFs == nil {
|
||||
certFs = os.DirFS(".")
|
||||
}
|
||||
cert, err := filepathOrContent(certFile, certFs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key, err := filepathOrContent(keyFile, certFs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cer, err := tls.X509KeyPair(cert, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
|
||||
configureTLS(&sc, tlsConfig)
|
||||
if sc.TLSConfigFunc != nil {
|
||||
sc.TLSConfigFunc(tlsConfig)
|
||||
}
|
||||
|
||||
listener, err := createListener(&sc, tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return serve(&sc, &s, listener, logger)
|
||||
}
|
||||
|
||||
func serve(sc *StartConfig, server *http.Server, listener net.Listener, logger Logger) error {
|
||||
if sc.BeforeServeFunc != nil {
|
||||
if err := sc.BeforeServeFunc(server); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
startupGreetings(sc, logger, listener)
|
||||
|
||||
if sc.GracefulContext != nil {
|
||||
ctx, cancel := stdContext.WithCancel(sc.GracefulContext)
|
||||
defer cancel() // make sure this graceful coroutine will end when serve returns by some other means
|
||||
go gracefulShutdown(ctx, sc, server, logger)
|
||||
}
|
||||
return server.Serve(listener)
|
||||
}
|
||||
|
||||
func configureTLS(sc *StartConfig, tlsConfig *tls.Config) {
|
||||
if !sc.DisableHTTP2 {
|
||||
tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2")
|
||||
}
|
||||
}
|
||||
|
||||
func createListener(sc *StartConfig, tlsConfig *tls.Config) (net.Listener, error) {
|
||||
listenerNetwork := sc.ListenerNetwork
|
||||
if listenerNetwork == "" {
|
||||
listenerNetwork = "tcp"
|
||||
}
|
||||
|
||||
var listener net.Listener
|
||||
var err error
|
||||
if tlsConfig != nil {
|
||||
listener, err = tls.Listen(listenerNetwork, sc.Address, tlsConfig)
|
||||
} else {
|
||||
listener, err = net.Listen(listenerNetwork, sc.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sc.ListenerAddrFunc != nil {
|
||||
sc.ListenerAddrFunc(listener.Addr())
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func startupGreetings(sc *StartConfig, logger Logger, listener net.Listener) {
|
||||
if !sc.HideBanner {
|
||||
bannerText := fmt.Sprintf(banner, "v"+Version, website)
|
||||
logger.Write([]byte(bannerText))
|
||||
}
|
||||
|
||||
if !sc.HidePort {
|
||||
logger.Write([]byte(fmt.Sprintf("⇨ http(s) server started on %s\n", listener.Addr())))
|
||||
}
|
||||
}
|
||||
|
||||
func filepathOrContent(fileOrContent interface{}, certFilesystem fs.FS) (content []byte, err error) {
|
||||
switch v := fileOrContent.(type) {
|
||||
case string:
|
||||
return fs.ReadFile(certFilesystem, v)
|
||||
case []byte:
|
||||
return v, nil
|
||||
default:
|
||||
return nil, ErrInvalidCertOrKeyType
|
||||
}
|
||||
}
|
||||
|
||||
func gracefulShutdown(gracefulCtx stdContext.Context, sc *StartConfig, server *http.Server, logger Logger) {
|
||||
<-gracefulCtx.Done() // wait until shutdown context is closed.
|
||||
// note: is server if closed by other means this method is still run but is good as no-op
|
||||
|
||||
timeout := sc.GracefulTimeout
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
shutdownCtx, cancel := stdContext.WithTimeout(stdContext.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
// we end up here when listeners are not shut down within given timeout
|
||||
if sc.OnShutdownError != nil {
|
||||
sc.OnShutdownError(err)
|
||||
return
|
||||
}
|
||||
logger.Error(fmt.Errorf("failed to shut down server within given timeout: %w", err))
|
||||
}
|
||||
}
|
815
server_test.go
Normal file
815
server_test.go
Normal file
@ -0,0 +1,815 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
stdContext "context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/http2"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func startOnRandomPort(ctx stdContext.Context, e *Echo) (string, error) {
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
errCh <- (&StartConfig{
|
||||
Address: ":0",
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}).Start(e)
|
||||
}()
|
||||
|
||||
return waitForServerStart(addrChan, errCh)
|
||||
}
|
||||
|
||||
func waitForServerStart(addrChan <-chan string, errCh <-chan error) (string, error) {
|
||||
waitCtx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// wait for addr to arrive
|
||||
for {
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return "", waitCtx.Err()
|
||||
case addr := <-addrChan:
|
||||
return addr, nil
|
||||
case err := <-errCh:
|
||||
if err == http.ErrServerClosed { // was closed normally before listener callback was called. should not be possible
|
||||
return "", nil
|
||||
}
|
||||
// failed to start and we did not manage to get even listener part.
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doGet(url string) (int, string, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp.StatusCode, "", err
|
||||
}
|
||||
return resp.StatusCode, string(body), nil
|
||||
}
|
||||
|
||||
func TestStartConfig_Start(t *testing.T) {
|
||||
e := New()
|
||||
e.GET("/ok", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer shutdown()
|
||||
go func() {
|
||||
errCh <- (&StartConfig{
|
||||
Address: ":0",
|
||||
GracefulContext: ctx,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}).Start(e)
|
||||
}()
|
||||
|
||||
addr, err := waitForServerStart(addrChan, errCh)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// check if server is actually up
|
||||
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, code)
|
||||
assert.Equal(t, "OK", body)
|
||||
|
||||
shutdown()
|
||||
|
||||
<-errCh // we will be blocking here until server returns from http.Serve
|
||||
|
||||
// check if server was stopped
|
||||
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
|
||||
assert.Equal(t, 0, code)
|
||||
assert.Equal(t, "", body)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("missing error")
|
||||
return
|
||||
}
|
||||
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
|
||||
}
|
||||
|
||||
func TestStartConfig_GracefulShutdown(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenHandlerTakesLonger bool
|
||||
expectBody string
|
||||
expectGracefulError string
|
||||
}{
|
||||
{
|
||||
name: "ok, all handlers returns before graceful shutdown deadline",
|
||||
whenHandlerTakesLonger: false,
|
||||
expectBody: "OK",
|
||||
expectGracefulError: "",
|
||||
},
|
||||
{
|
||||
name: "nok, handlers do not returns before graceful shutdown deadline",
|
||||
whenHandlerTakesLonger: true,
|
||||
expectBody: "timeout",
|
||||
expectGracefulError: stdContext.DeadlineExceeded.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
e.GET("/ok", func(c Context) error {
|
||||
msg := "OK"
|
||||
if tc.whenHandlerTakesLonger {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
msg = "timeout"
|
||||
}
|
||||
return c.String(http.StatusOK, msg)
|
||||
})
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 50*time.Millisecond)
|
||||
defer shutdown()
|
||||
|
||||
shutdownErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- (&StartConfig{
|
||||
Address: ":0",
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 50 * time.Millisecond,
|
||||
OnShutdownError: func(err error) {
|
||||
shutdownErrChan <- err
|
||||
},
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}).Start(e)
|
||||
}()
|
||||
|
||||
addr, err := waitForServerStart(addrChan, errCh)
|
||||
assert.NoError(t, err)
|
||||
|
||||
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
|
||||
if err != nil {
|
||||
assert.NoError(t, err)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, code)
|
||||
assert.Equal(t, tc.expectBody, body)
|
||||
|
||||
var shutdownErr error
|
||||
select {
|
||||
case shutdownErr = <-shutdownErrChan:
|
||||
default:
|
||||
}
|
||||
if tc.expectGracefulError != "" {
|
||||
assert.EqualError(t, shutdownErr, tc.expectGracefulError)
|
||||
} else {
|
||||
assert.NoError(t, shutdownErr)
|
||||
}
|
||||
|
||||
shutdown()
|
||||
|
||||
<-errCh // we will be blocking here until server returns from http.Serve
|
||||
|
||||
// check if server was stopped
|
||||
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
|
||||
assert.Error(t, err)
|
||||
if err != nil {
|
||||
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
|
||||
}
|
||||
assert.Equal(t, 0, code)
|
||||
assert.Equal(t, "", body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartConfig_Start_withTLSConfigFunc(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
tlsConfigCalled := false
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
TLSConfigFunc: func(tlsConfig *tls.Config) {
|
||||
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, errors.New("not_implemented")
|
||||
}
|
||||
tlsConfigCalled = true
|
||||
},
|
||||
BeforeServeFunc: func(s *http.Server) error {
|
||||
return errors.New("stop_now")
|
||||
},
|
||||
}
|
||||
err := s.Start(e)
|
||||
assert.EqualError(t, err, "stop_now")
|
||||
assert.True(t, tlsConfigCalled)
|
||||
}
|
||||
|
||||
func TestStartConfig_Start_createListenerError(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
TLSConfigFunc: func(tlsConfig *tls.Config) {
|
||||
},
|
||||
BeforeServeFunc: func(s *http.Server) error {
|
||||
return errors.New("stop_now")
|
||||
},
|
||||
}
|
||||
err := s.Start(e)
|
||||
assert.EqualError(t, err, "tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
|
||||
}
|
||||
|
||||
func TestStartConfig_StartTLS(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
addr string
|
||||
certFile string
|
||||
keyFile string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
addr: ":0",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid certFile",
|
||||
addr: ":0",
|
||||
certFile: "not existing",
|
||||
expectError: "open not existing: no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid keyFile",
|
||||
addr: ":0",
|
||||
keyFile: "not existing",
|
||||
expectError: "open not existing: no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "nok, failed to create cert out of certFile and keyFile",
|
||||
addr: ":0",
|
||||
keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
|
||||
expectError: "tls: found a certificate rather than a key in the PEM for the private key",
|
||||
},
|
||||
{
|
||||
name: "nok, invalid tls address",
|
||||
addr: "nope",
|
||||
expectError: "listen tcp: address nope: missing port in address",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer shutdown()
|
||||
go func() {
|
||||
certFile := "_fixture/certs/cert.pem"
|
||||
if tc.certFile != "" {
|
||||
certFile = tc.certFile
|
||||
}
|
||||
keyFile := "_fixture/certs/key.pem"
|
||||
if tc.keyFile != "" {
|
||||
keyFile = tc.keyFile
|
||||
}
|
||||
|
||||
s := &StartConfig{
|
||||
Address: tc.addr,
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
errCh <- s.StartTLS(e, certFile, keyFile)
|
||||
}()
|
||||
|
||||
_, err := waitForServerStart(addrChan, errCh)
|
||||
|
||||
if tc.expectError != "" {
|
||||
if _, ok := err.(*os.PathError); ok {
|
||||
assert.Error(t, err) // error messages for unix and windows are different. so name only error type here
|
||||
} else {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartConfig_StartTLS_withTLSConfigFunc(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
tlsConfigCalled := false
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
TLSConfigFunc: func(tlsConfig *tls.Config) {
|
||||
assert.Len(t, tlsConfig.Certificates, 1)
|
||||
tlsConfigCalled = true
|
||||
},
|
||||
BeforeServeFunc: func(s *http.Server) error {
|
||||
return errors.New("stop_now")
|
||||
},
|
||||
}
|
||||
err := s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
|
||||
|
||||
assert.EqualError(t, err, "stop_now")
|
||||
assert.True(t, tlsConfigCalled)
|
||||
}
|
||||
|
||||
func TestStartConfig_StartTLSAndStart(t *testing.T) {
|
||||
// We name if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
|
||||
e := New()
|
||||
e.GET("/", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
tlsCtx, tlsShutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
|
||||
defer tlsShutdown()
|
||||
addrTLSChan := make(chan string)
|
||||
errTLSChan := make(chan error)
|
||||
go func() {
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
GracefulContext: tlsCtx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrTLSChan <- addr.String()
|
||||
},
|
||||
}
|
||||
errTLSChan <- s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
|
||||
}()
|
||||
|
||||
tlsAddr, err := waitForServerStart(addrTLSChan, errTLSChan)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
|
||||
client := &http.Client{Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}}
|
||||
res, err := client.Get(fmt.Sprintf("https://%v", tlsAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
|
||||
defer shutdown()
|
||||
addrChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
errChan <- s.Start(e)
|
||||
}()
|
||||
|
||||
addr, err := waitForServerStart(addrChan, errChan)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
|
||||
res, err = client.Get(fmt.Sprintf("http://%v", addr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
// see if HTTPS works after HTTP listener is also added
|
||||
res, err = client.Get(fmt.Sprintf("https://%v", tlsAddr))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestFilepathOrContent(t *testing.T) {
|
||||
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
|
||||
require.NoError(t, err)
|
||||
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
cert interface{}
|
||||
key interface{}
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: `ValidCertAndKeyFilePath`,
|
||||
cert: "_fixture/certs/cert.pem",
|
||||
key: "_fixture/certs/key.pem",
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: `ValidCertAndKeyByteString`,
|
||||
cert: cert,
|
||||
key: key,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: `InvalidKeyType`,
|
||||
cert: cert,
|
||||
key: 1,
|
||||
expectedErr: ErrInvalidCertOrKeyType,
|
||||
},
|
||||
{
|
||||
name: `InvalidCertType`,
|
||||
cert: 0,
|
||||
key: key,
|
||||
expectedErr: ErrInvalidCertOrKeyType,
|
||||
},
|
||||
{
|
||||
name: `InvalidCertAndKeyTypes`,
|
||||
cert: 0,
|
||||
key: 1,
|
||||
expectedErr: ErrInvalidCertOrKeyType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer shutdown()
|
||||
|
||||
go func() {
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
CertFilesystem: os.DirFS("."),
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
errCh <- s.StartTLS(e, tc.cert, tc.key)
|
||||
}()
|
||||
|
||||
_, err := waitForServerStart(addrChan, errCh)
|
||||
if tc.expectedErr != nil {
|
||||
assert.EqualError(t, err, tc.expectedErr.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func supportsIPv6() bool {
|
||||
addrs, _ := net.InterfaceAddrs()
|
||||
for _, addr := range addrs {
|
||||
// Check if any interface has local IPv6 assigned
|
||||
if strings.Contains(addr.String(), "::1") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestStartConfig_WithListenerNetwork(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
network string
|
||||
address string
|
||||
}{
|
||||
{
|
||||
name: "tcp ipv4 address",
|
||||
network: "tcp",
|
||||
address: "127.0.0.1:1323",
|
||||
},
|
||||
{
|
||||
name: "tcp ipv6 address",
|
||||
network: "tcp",
|
||||
address: "[::1]:1323",
|
||||
},
|
||||
{
|
||||
name: "tcp4 ipv4 address",
|
||||
network: "tcp4",
|
||||
address: "127.0.0.1:1323",
|
||||
},
|
||||
{
|
||||
name: "tcp6 ipv6 address",
|
||||
network: "tcp6",
|
||||
address: "[::1]:1323",
|
||||
},
|
||||
}
|
||||
|
||||
hasIPv6 := supportsIPv6()
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !hasIPv6 && strings.Contains(tc.address, "::") {
|
||||
t.Skip("Skipping testing IPv6 for " + tc.address + ", not available")
|
||||
}
|
||||
|
||||
e := New()
|
||||
e.GET("/ok", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer shutdown()
|
||||
|
||||
go func() {
|
||||
s := &StartConfig{
|
||||
Address: tc.address,
|
||||
ListenerNetwork: tc.network,
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
errCh <- s.Start(e)
|
||||
}()
|
||||
|
||||
_, err := waitForServerStart(addrChan, errCh)
|
||||
assert.NoError(t, err)
|
||||
|
||||
code, body, err := doGet(fmt.Sprintf("http://%s/ok", tc.address))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, code)
|
||||
assert.Equal(t, "OK", body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartConfig_WithHideBanner(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
hideBanner bool
|
||||
}{
|
||||
{
|
||||
name: "hide banner on startup",
|
||||
hideBanner: true,
|
||||
},
|
||||
{
|
||||
name: "show banner on startup",
|
||||
hideBanner: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger = &testLogger{output: buf}
|
||||
|
||||
e.GET("/ok", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer shutdown()
|
||||
|
||||
go func() {
|
||||
_, err := waitForServerStart(addrChan, errCh)
|
||||
errCh <- err
|
||||
shutdown()
|
||||
}()
|
||||
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
HideBanner: tc.hideBanner,
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.Start(e); err != http.ErrServerClosed {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, <-errCh)
|
||||
|
||||
contains := strings.Contains(buf.String(), "High performance, minimalist Go web framework")
|
||||
if tc.hideBanner {
|
||||
assert.False(t, contains)
|
||||
} else {
|
||||
assert.True(t, contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartConfig_WithHidePort(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
hidePort bool
|
||||
}{
|
||||
{
|
||||
name: "hide port on startup",
|
||||
hidePort: true,
|
||||
},
|
||||
{
|
||||
name: "show port on startup",
|
||||
hidePort: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
e.Logger = &testLogger{output: buf}
|
||||
|
||||
e.GET("/ok", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
|
||||
go func() {
|
||||
_, err := waitForServerStart(addrChan, errCh)
|
||||
errCh <- err
|
||||
shutdown()
|
||||
}()
|
||||
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
HidePort: tc.hidePort,
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
if err := s.Start(e); err != http.ErrServerClosed {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, <-errCh)
|
||||
|
||||
portMsg := fmt.Sprintf("http(s) server started on")
|
||||
contains := strings.Contains(buf.String(), portMsg)
|
||||
if tc.hidePort {
|
||||
assert.False(t, contains)
|
||||
} else {
|
||||
assert.True(t, contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartConfig_WithBeforeServeFunc(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
e.GET("/ok", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
BeforeServeFunc: func(s *http.Server) error {
|
||||
return errors.New("is called before serve")
|
||||
},
|
||||
}
|
||||
err := s.Start(e)
|
||||
assert.EqualError(t, err, "is called before serve")
|
||||
}
|
||||
|
||||
func TestWithDisableHTTP2(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
disableHTTP2 bool
|
||||
}{
|
||||
{
|
||||
name: "HTTP2 enabled",
|
||||
disableHTTP2: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP2 disabled",
|
||||
disableHTTP2: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
e.GET("/ok", func(c Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
addrChan := make(chan string)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
|
||||
defer shutdown()
|
||||
|
||||
go func() {
|
||||
certFile := "_fixture/certs/cert.pem"
|
||||
keyFile := "_fixture/certs/key.pem"
|
||||
|
||||
s := &StartConfig{
|
||||
Address: ":0",
|
||||
DisableHTTP2: tc.disableHTTP2,
|
||||
GracefulContext: ctx,
|
||||
GracefulTimeout: 100 * time.Millisecond,
|
||||
ListenerAddrFunc: func(addr net.Addr) {
|
||||
addrChan <- addr.String()
|
||||
},
|
||||
}
|
||||
errCh <- s.StartTLS(e, certFile, keyFile)
|
||||
}()
|
||||
|
||||
addr, err := waitForServerStart(addrChan, errCh)
|
||||
assert.NoError(t, err)
|
||||
|
||||
url := fmt.Sprintf("https://%v/ok", addr)
|
||||
|
||||
// do ordinary http(s) request
|
||||
client := &http.Client{Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}}
|
||||
res, err := client.Get(url)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
// do HTTP2 request
|
||||
client.Transport = &http2.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
if tc.disableHTTP2 {
|
||||
assert.True(t, strings.Contains(err.Error(), `http2: unexpected ALPN protocol ""; want "h2"`))
|
||||
return
|
||||
}
|
||||
log.Fatalf("Failed get: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed reading response body: %s", err)
|
||||
}
|
||||
assert.Equal(t, "OK", string(body))
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
output io.Writer
|
||||
}
|
||||
|
||||
func (l *testLogger) Write(p []byte) (n int, err error) {
|
||||
return l.output.Write(p)
|
||||
}
|
||||
|
||||
func (l *testLogger) Printf(format string, args ...interface{}) {
|
||||
_, _ = l.output.Write([]byte(fmt.Sprintf(format, args...)))
|
||||
}
|
||||
|
||||
func (l *testLogger) Error(err error) {
|
||||
_, _ = l.output.Write([]byte(err.Error()))
|
||||
}
|
Loading…
Reference in New Issue
Block a user