1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-07 23:01:56 +02:00

Merge branch 'v5_alpha' into v5_alpha_labstack

# Conflicts:
#	.github/workflows/echo.yml
#	Makefile
#	context.go
#	context_test.go
#	echo.go
#	echo_test.go
#	group.go
#	group_test.go
#	middleware/basic_auth.go
#	middleware/basic_auth_test.go
#	middleware/body_limit_test.go
#	middleware/decompress_test.go
#	middleware/extractor.go
#	middleware/jwt.go
#	middleware/jwt_external_test.go
#	middleware/jwt_test.go
#	middleware/key_auth.go
#	middleware/key_auth_test.go
#	middleware/logger.go
#	middleware/method_override_test.go
#	middleware/recover.go
#	middleware/recover_test.go
#	router.go
#	router_test.go
#	server_test.go
This commit is contained in:
toimtoimtoim 2022-07-17 23:38:27 +03:00
commit 74022662be
No known key found for this signature in database
GPG Key ID: 468EA66F309CF886
34 changed files with 2302 additions and 502 deletions

View File

@ -19,7 +19,7 @@ on:
- '_fixture/**'
- '.github/**'
- 'codecov.yml'
workflow_dispatch: # to be able to run workflow manually
workflow_dispatch:
jobs:
test:
@ -29,72 +29,72 @@ jobs:
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases
# except v5 starts from 1.17 until there is last four major releases after that
go: [1.17]
go: [1.17, 1.18]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Checkout Code
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
ref: ${{ github.ref }}
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install Dependencies
run: go get -v golang.org/x/lint/golint
run: |
go install golang.org/x/lint/golint@latest
go install honnef.co/go/tools/cmd/staticcheck@latest
- name: Run Tests
run: |
golint -set_exit_status ./...
staticcheck ./...
go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
- name: Upload coverage to Codecov
if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v2
if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v1
with:
token:
fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.17]
go: [1.18]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
- name: Checkout Code (Previous)
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
path: new
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Install Dependencies
run: go get -v golang.org/x/perf/cmd/benchstat
run: go install golang.org/x/perf/cmd/benchstat@latest
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt

View File

@ -1,5 +1,50 @@
# Changelog
## v4.7.2 - 2022-03-16
**Fixes**
* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131)
* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136)
* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126)
**Enhancements**
* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134)
## v4.7.1 - 2022-03-13
**Fixes**
* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123)
**Enhancements**
* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116)
## v4.7.0 - 2022-03-01
**Enhancements**
* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060)
* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072)
* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027)
* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064)
**Fixes**
* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007)
* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102)
**General**
* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103)
* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078)
* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049)
* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README
## v4.6.3 - 2022-01-10
**Fixes**

View File

@ -9,7 +9,8 @@ tag:
check: lint vet race ## Check project
init:
@go get -u golang.org/x/lint/golint
@go install golang.org/x/lint/golint@latest
@go install honnef.co/go/tools/cmd/staticcheck@latest
lint: ## Lint the files
@golint -set_exit_status ${PKG_LIST}
@ -29,6 +30,6 @@ benchmark: ## Run benchmarks
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.16"
test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
goversion ?= "1.17"
test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"

197
binder.go
View File

@ -1,6 +1,8 @@
package echo
import (
"encoding"
"encoding/json"
"fmt"
"net/http"
"strconv"
@ -52,8 +54,11 @@ import (
* time
* duration
* BindUnmarshaler() interface
* TextUnmarshaler() interface
* JSONUnmarshaler() interface
* UnixTime() - converts unix time (integer) to time.Time
* UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time
* UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time
* UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time
* CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error`
*/
@ -204,7 +209,7 @@ func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []st
return b.customFunc(sourceParam, customFunc, false)
}
// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist.
// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist.
func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
return b.customFunc(sourceParam, customFunc, true)
}
@ -241,7 +246,7 @@ func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder {
return b
}
// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist
// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist
func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder {
if b.failFast && b.errors != nil {
return b
@ -270,7 +275,7 @@ func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder {
return b
}
// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist
// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist
func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder {
if b.failFast && b.errors != nil {
return b
@ -302,7 +307,7 @@ func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler)
return b
}
// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface.
// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface.
// Returns error when value does not exist
func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
if b.failFast && b.errors != nil {
@ -321,13 +326,85 @@ func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshal
return b
}
// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface
func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
if b.failFast && b.errors != nil {
return b
}
tmp := b.ValueFunc(sourceParam)
if tmp == "" {
return b
}
if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
}
return b
}
// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface.
// Returns error when value does not exist
func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
if b.failFast && b.errors != nil {
return b
}
tmp := b.ValueFunc(sourceParam)
if tmp == "" {
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
return b
}
if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
}
return b
}
// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface
func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
if b.failFast && b.errors != nil {
return b
}
tmp := b.ValueFunc(sourceParam)
if tmp == "" {
return b
}
if err := dest.UnmarshalText([]byte(tmp)); err != nil {
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
}
return b
}
// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface.
// Returns error when value does not exist
func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
if b.failFast && b.errors != nil {
return b
}
tmp := b.ValueFunc(sourceParam)
if tmp == "" {
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
return b
}
if err := dest.UnmarshalText([]byte(tmp)); err != nil {
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
}
return b
}
// BindWithDelimiter binds parameter to destination by suitable conversion function.
// Delimiter is used before conversion to split parameter value to separate values
func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder {
return b.bindWithDelimiter(sourceParam, dest, delimiter, false)
}
// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function.
// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function.
// Delimiter is used before conversion to split parameter value to separate values
func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder {
return b.bindWithDelimiter(sourceParam, dest, delimiter, true)
@ -376,7 +453,7 @@ func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder {
return b.intValue(sourceParam, dest, 64, false)
}
// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist
// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist
func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder {
return b.intValue(sourceParam, dest, 64, true)
}
@ -386,7 +463,7 @@ func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder {
return b.intValue(sourceParam, dest, 32, false)
}
// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist
// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist
func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder {
return b.intValue(sourceParam, dest, 32, true)
}
@ -396,7 +473,7 @@ func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder {
return b.intValue(sourceParam, dest, 16, false)
}
// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist
// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist
func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder {
return b.intValue(sourceParam, dest, 16, true)
}
@ -406,7 +483,7 @@ func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder {
return b.intValue(sourceParam, dest, 8, false)
}
// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist
// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist
func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder {
return b.intValue(sourceParam, dest, 8, true)
}
@ -416,7 +493,7 @@ func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder {
return b.intValue(sourceParam, dest, 0, false)
}
// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist
// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist
func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder {
return b.intValue(sourceParam, dest, 0, true)
}
@ -544,7 +621,7 @@ func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder {
return b.intsValue(sourceParam, dest, false)
}
// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist
// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder {
return b.intsValue(sourceParam, dest, true)
}
@ -554,7 +631,7 @@ func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder {
return b.intsValue(sourceParam, dest, false)
}
// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist
// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder {
return b.intsValue(sourceParam, dest, true)
}
@ -564,7 +641,7 @@ func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder {
return b.intsValue(sourceParam, dest, false)
}
// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist
// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder {
return b.intsValue(sourceParam, dest, true)
}
@ -574,7 +651,7 @@ func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder {
return b.intsValue(sourceParam, dest, false)
}
// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist
// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder {
return b.intsValue(sourceParam, dest, true)
}
@ -584,7 +661,7 @@ func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder {
return b.intsValue(sourceParam, dest, false)
}
// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist
// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist
func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder {
return b.intsValue(sourceParam, dest, true)
}
@ -594,7 +671,7 @@ func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder {
return b.uintValue(sourceParam, dest, 64, false)
}
// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist
// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist
func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder {
return b.uintValue(sourceParam, dest, 64, true)
}
@ -604,7 +681,7 @@ func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder {
return b.uintValue(sourceParam, dest, 32, false)
}
// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist
// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist
func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder {
return b.uintValue(sourceParam, dest, 32, true)
}
@ -614,7 +691,7 @@ func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder {
return b.uintValue(sourceParam, dest, 16, false)
}
// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist
// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist
func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder {
return b.uintValue(sourceParam, dest, 16, true)
}
@ -624,7 +701,7 @@ func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder {
return b.uintValue(sourceParam, dest, 8, false)
}
// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist
// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist
func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder {
return b.uintValue(sourceParam, dest, 8, true)
}
@ -634,7 +711,7 @@ func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder {
return b.uintValue(sourceParam, dest, 8, false)
}
// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist
// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist
func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder {
return b.uintValue(sourceParam, dest, 8, true)
}
@ -644,7 +721,7 @@ func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder {
return b.uintValue(sourceParam, dest, 0, false)
}
// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist
// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist
func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder {
return b.uintValue(sourceParam, dest, 0, true)
}
@ -772,7 +849,7 @@ func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder {
return b.uintsValue(sourceParam, dest, false)
}
// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist
// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder {
return b.uintsValue(sourceParam, dest, true)
}
@ -782,7 +859,7 @@ func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder {
return b.uintsValue(sourceParam, dest, false)
}
// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist
// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder {
return b.uintsValue(sourceParam, dest, true)
}
@ -792,7 +869,7 @@ func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder {
return b.uintsValue(sourceParam, dest, false)
}
// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist
// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder {
return b.uintsValue(sourceParam, dest, true)
}
@ -802,7 +879,7 @@ func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder {
return b.uintsValue(sourceParam, dest, false)
}
// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist
// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist
func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder {
return b.uintsValue(sourceParam, dest, true)
}
@ -812,7 +889,7 @@ func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder {
return b.uintsValue(sourceParam, dest, false)
}
// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist
// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist
func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder {
return b.uintsValue(sourceParam, dest, true)
}
@ -822,7 +899,7 @@ func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder {
return b.boolValue(sourceParam, dest, false)
}
// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist
// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist
func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder {
return b.boolValue(sourceParam, dest, true)
}
@ -887,7 +964,7 @@ func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder {
return b.boolsValue(sourceParam, dest, false)
}
// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist
// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist
func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder {
return b.boolsValue(sourceParam, dest, true)
}
@ -897,7 +974,7 @@ func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder {
return b.floatValue(sourceParam, dest, 64, false)
}
// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist
// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist
func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder {
return b.floatValue(sourceParam, dest, 64, true)
}
@ -907,7 +984,7 @@ func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder {
return b.floatValue(sourceParam, dest, 32, false)
}
// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist
// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist
func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder {
return b.floatValue(sourceParam, dest, 32, true)
}
@ -992,7 +1069,7 @@ func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder
return b.floatsValue(sourceParam, dest, false)
}
// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist
// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist
func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder {
return b.floatsValue(sourceParam, dest, true)
}
@ -1002,7 +1079,7 @@ func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder
return b.floatsValue(sourceParam, dest, false)
}
// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist
// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist
func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder {
return b.floatsValue(sourceParam, dest, true)
}
@ -1012,7 +1089,7 @@ func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *
return b.time(sourceParam, dest, layout, false)
}
// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist
// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist
func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder {
return b.time(sourceParam, dest, layout, true)
}
@ -1043,7 +1120,7 @@ func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string
return b.times(sourceParam, dest, layout, false)
}
// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist
// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist
func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
return b.times(sourceParam, dest, layout, true)
}
@ -1084,7 +1161,7 @@ func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBi
return b.duration(sourceParam, dest, false)
}
// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist
// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist
func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder {
return b.duration(sourceParam, dest, true)
}
@ -1115,7 +1192,7 @@ func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *Valu
return b.durationsValue(sourceParam, dest, false)
}
// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist
// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist
func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder {
return b.durationsValue(sourceParam, dest, true)
}
@ -1161,10 +1238,10 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, false, false)
return b.unixTime(sourceParam, dest, false, time.Second)
}
// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding
// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding
// to the given Unix time). Returns error when value does not exist.
//
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
@ -1172,10 +1249,31 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, true, false)
return b.unixTime(sourceParam, dest, true, time.Second)
}
// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision).
// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision).
//
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, false, time.Millisecond)
}
// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding
// to the given Unix time in millisecond precision). Returns error when value does not exist.
//
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
//
// Note:
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, true, time.Millisecond)
}
// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision).
//
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
@ -1185,10 +1283,10 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, false, true)
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
}
// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding
// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding
// to the given Unix time value in nano second precision). Returns error when value does not exist.
//
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
@ -1199,10 +1297,10 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
return b.unixTime(sourceParam, dest, true, true)
return b.unixTime(sourceParam, dest, true, time.Nanosecond)
}
func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder {
func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder {
if b.failFast && b.errors != nil {
return b
}
@ -1221,10 +1319,13 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi
return b
}
if isNano {
*dest = time.Unix(0, n)
} else {
switch precision {
case time.Second:
*dest = time.Unix(n, 0)
case time.Millisecond:
*dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows
case time.Nanosecond:
*dest = time.Unix(0, n)
}
return b
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"github.com/stretchr/testify/assert"
"io"
"math/big"
"net/http"
"net/http/httptest"
"strconv"
@ -55,7 +56,7 @@ func TestBindingError_Error(t *testing.T) {
func TestBindingError_ErrorJSON(t *testing.T) {
err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
resp, err := json.Marshal(err)
resp, _ := json.Marshal(err)
assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp))
}
@ -2188,6 +2189,188 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) {
}
}
func TestValueBinder_JSONUnmarshaler(t *testing.T) {
example := big.NewInt(999)
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue big.Int
expectError string
}{
{
name: "ok, binds value",
whenURL: "/search?param=999&param=998",
expectValue: *example,
},
{
name: "ok, params values empty, value is not changed",
whenURL: "/search?nope=1",
expectValue: big.Int{},
},
{
name: "nok, previous errors fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: big.Int{},
expectError: "previous error",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=xxx",
expectValue: big.Int{},
expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
{
name: "ok (must), binds value",
whenMust: true,
whenURL: "/search?param=999&param=998",
expectValue: *example,
},
{
name: "ok (must), params values empty, returns error, value is not changed",
whenMust: true,
whenURL: "/search?nope=1",
expectValue: big.Int{},
expectError: "code=400, message=required field value is empty, field=param",
},
{
name: "nok (must), previous errors fail fast without binding value",
givenFailFast: true,
whenMust: true,
whenURL: "/search?param=1&param=xxx",
expectValue: big.Int{},
expectError: "previous error",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=xxx",
expectValue: big.Int{},
expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
var dest big.Int
var err error
if tc.whenMust {
err = b.MustJSONUnmarshaler("param", &dest).BindError()
} else {
err = b.JSONUnmarshaler("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_TextUnmarshaler(t *testing.T) {
example := big.NewInt(999)
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue big.Int
expectError string
}{
{
name: "ok, binds value",
whenURL: "/search?param=999&param=998",
expectValue: *example,
},
{
name: "ok, params values empty, value is not changed",
whenURL: "/search?nope=1",
expectValue: big.Int{},
},
{
name: "nok, previous errors fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: big.Int{},
expectError: "previous error",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=xxx",
expectValue: big.Int{},
expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
{
name: "ok (must), binds value",
whenMust: true,
whenURL: "/search?param=999&param=998",
expectValue: *example,
},
{
name: "ok (must), params values empty, returns error, value is not changed",
whenMust: true,
whenURL: "/search?nope=1",
expectValue: big.Int{},
expectError: "code=400, message=required field value is empty, field=param",
},
{
name: "nok (must), previous errors fail fast without binding value",
givenFailFast: true,
whenMust: true,
whenURL: "/search?param=1&param=xxx",
expectValue: big.Int{},
expectError: "previous error",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=xxx",
expectValue: big.Int{},
expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
var dest big.Int
var err error
if tc.whenMust {
err = b.MustTextUnmarshaler("param", &dest).BindError()
} else {
err = b.TextUnmarshaler("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue, dest)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
var testCases = []struct {
name string
@ -2530,6 +2713,97 @@ func TestValueBinder_UnixTime(t *testing.T) {
}
}
func TestValueBinder_UnixTimeMilli(t *testing.T) {
exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140
var testCases = []struct {
name string
givenFailFast bool
givenBindErrors []error
whenURL string
whenMust bool
expectValue time.Time
expectError string
}{
{
name: "ok, binds value, unix time in milliseconds",
whenURL: "/search?param=1647184410140&param=1647184410199",
expectValue: exampleTime,
},
{
name: "ok, params values empty, value is not changed",
whenURL: "/search?nope=1",
expectValue: time.Time{},
},
{
name: "nok, previous errors fail fast without binding value",
givenFailFast: true,
whenURL: "/search?param=1&param=100",
expectValue: time.Time{},
expectError: "previous error",
},
{
name: "nok, conversion fails, value is not changed",
whenURL: "/search?param=nope&param=100",
expectValue: time.Time{},
expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
{
name: "ok (must), binds value",
whenMust: true,
whenURL: "/search?param=1647184410140&param=1647184410199",
expectValue: exampleTime,
},
{
name: "ok (must), params values empty, returns error, value is not changed",
whenMust: true,
whenURL: "/search?nope=1",
expectValue: time.Time{},
expectError: "code=400, message=required field value is empty, field=param",
},
{
name: "nok (must), previous errors fail fast without binding value",
givenFailFast: true,
whenMust: true,
whenURL: "/search?param=1&param=100",
expectValue: time.Time{},
expectError: "previous error",
},
{
name: "nok (must), conversion fails, value is not changed",
whenMust: true,
whenURL: "/search?param=nope&param=100",
expectValue: time.Time{},
expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
dest := time.Time{}
var err error
if tc.whenMust {
err = b.MustUnixTimeMilli("param", &dest).BindError()
} else {
err = b.UnixTimeMilli("param", &dest).BindError()
}
assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_UnixTimeNano(t *testing.T) {
exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603
exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789

View File

@ -55,6 +55,16 @@ type Context interface {
// PathParam returns path parameter by name.
PathParam(name string) string
// PathParamDefault returns the path parameter or default value for the provided name.
//
// Notes for DefaultRouter implementation:
// Path parameter could be empty for cases like that:
// * route `/release-:version/bin` and request URL is `/release-/bin`
// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg`
// but not when path parameter is last part of route path
// * route `/download/file.:ext` will not match request `/download/file.`
PathParamDefault(name string, defaultValue string) string
// PathParams returns path parameter values.
PathParams() PathParams
@ -176,6 +186,9 @@ type Context interface {
Redirect(code int, url string) error
// Echo returns the `Echo` instance.
//
// WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them
// anywhere in your code after Echo server has started.
Echo() *Echo
}
@ -379,7 +392,10 @@ func (c *DefaultContext) PathParam(name string) string {
return c.pathParams.Get(name, "")
}
// PathParamDefault does not exist as expecting empty path param makes no sense
// PathParamDefault returns the path parameter or default value for the provided name.
func (c *DefaultContext) PathParamDefault(name, defaultValue string) string {
return c.pathParams.Get(name, defaultValue)
}
// PathParams returns path parameter values.
func (c *DefaultContext) PathParams() PathParams {
@ -406,7 +422,8 @@ func (c *DefaultContext) QueryParam(name string) string {
}
// QueryParamDefault returns the query param or default value for the provided name.
// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string
// Note: QueryParamDefault does not distinguish if query had no value by that name or value was empty string
// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamDefault("search", "1")`
func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string {
value := c.QueryParam(name)
if value == "" {

View File

@ -448,22 +448,142 @@ func TestContextCookie(t *testing.T) {
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPathParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil).(*DefaultContext)
params := &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
func TestContext_PathParams(t *testing.T) {
var testCases = []struct {
name string
given *PathParams
expect PathParams
}{
{
name: "param exists",
given: &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
},
expect: PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
},
},
{
name: "params is empty",
given: &PathParams{},
expect: PathParams{},
},
}
// ParamNames
c.pathParams = params
assert.EqualValues(t, *params, c.PathParams())
// Param
assert.Equal(t, "501", c.PathParam("fid"))
assert.Equal(t, "", c.PathParam("undefined"))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil)
c.(RoutableContext).SetRawPathParams(tc.given)
assert.EqualValues(t, tc.expect, c.PathParams())
})
}
}
func TestContext_PathParam(t *testing.T) {
var testCases = []struct {
name string
given *PathParams
whenParamName string
expect string
}{
{
name: "param exists",
given: &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
},
whenParamName: "uid",
expect: "101",
},
{
name: "multiple same param values exists - return first",
given: &PathParams{
{Name: "uid", Value: "101"},
{Name: "uid", Value: "202"},
{Name: "fid", Value: "501"},
},
whenParamName: "uid",
expect: "101",
},
{
name: "param does not exists",
given: &PathParams{
{Name: "uid", Value: "101"},
},
whenParamName: "nope",
expect: "",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil)
c.(RoutableContext).SetRawPathParams(tc.given)
assert.EqualValues(t, tc.expect, c.PathParam(tc.whenParamName))
})
}
}
func TestContext_PathParamDefault(t *testing.T) {
var testCases = []struct {
name string
given *PathParams
whenParamName string
whenDefaultValue string
expect string
}{
{
name: "param exists",
given: &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
},
whenParamName: "uid",
whenDefaultValue: "999",
expect: "101",
},
{
name: "param exists and is empty",
given: &PathParams{
{Name: "uid", Value: ""},
{Name: "fid", Value: "501"},
},
whenParamName: "uid",
whenDefaultValue: "999",
expect: "", // <-- this is different from QueryParamDefault behaviour
},
{
name: "param does not exists",
given: &PathParams{
{Name: "uid", Value: "101"},
},
whenParamName: "nope",
whenDefaultValue: "999",
expect: "999",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil)
c.(RoutableContext).SetRawPathParams(tc.given)
assert.EqualValues(t, tc.expect, c.PathParamDefault(tc.whenParamName, tc.whenDefaultValue))
})
}
}
func TestContextGetAndSetParam(t *testing.T) {
@ -568,27 +688,129 @@ func TestContextFormValue(t *testing.T) {
assert.Error(t, err)
}
func TestContextQueryParam(t *testing.T) {
q := make(url.Values)
q.Set("name", "Jon Snow")
q.Set("email", "jon@labstack.com")
req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
e := New()
c := e.NewContext(req, nil)
func TestContext_QueryParams(t *testing.T) {
var testCases = []struct {
name string
givenURL string
expect url.Values
}{
{
name: "multiple values in url",
givenURL: "/?test=1&test=2&email=jon%40labstack.com",
expect: url.Values{
"test": []string{"1", "2"},
"email": []string{"jon@labstack.com"},
},
},
{
name: "single value in url",
givenURL: "/?nope=1",
expect: url.Values{
"nope": []string{"1"},
},
},
{
name: "no query params in url",
givenURL: "/?",
expect: url.Values{},
},
}
// QueryParam
assert.Equal(t, "Jon Snow", c.QueryParam("name"))
assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
e := New()
c := e.NewContext(req, nil)
// QueryParamDefault
assert.Equal(t, "Jon Snow", c.QueryParamDefault("name", "nope"))
assert.Equal(t, "default", c.QueryParamDefault("missing", "default"))
assert.Equal(t, tc.expect, c.QueryParams())
})
}
}
// QueryParams
assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, c.QueryParams())
func TestContext_QueryParam(t *testing.T) {
var testCases = []struct {
name string
givenURL string
whenParamName string
expect string
}{
{
name: "value exists in url",
givenURL: "/?test=1",
whenParamName: "test",
expect: "1",
},
{
name: "multiple values exists in url",
givenURL: "/?test=9&test=8",
whenParamName: "test",
expect: "9", // <-- first value in returned
},
{
name: "value does not exists in url",
givenURL: "/?nope=1",
whenParamName: "test",
expect: "",
},
{
name: "value is empty in url",
givenURL: "/?test=",
whenParamName: "test",
expect: "",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
e := New()
c := e.NewContext(req, nil)
assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName))
})
}
}
func TestContext_QueryParamDefault(t *testing.T) {
var testCases = []struct {
name string
givenURL string
whenParamName string
whenDefaultValue string
expect string
}{
{
name: "value exists in url",
givenURL: "/?test=1",
whenParamName: "test",
whenDefaultValue: "999",
expect: "1",
},
{
name: "value does not exists in url",
givenURL: "/?nope=1",
whenParamName: "test",
whenDefaultValue: "999",
expect: "999",
},
{
name: "value is empty in url",
givenURL: "/?test=",
whenParamName: "test",
whenDefaultValue: "999",
expect: "999",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
e := New()
c := e.NewContext(req, nil)
assert.Equal(t, tc.expect, c.QueryParamDefault(tc.whenParamName, tc.whenDefaultValue))
})
}
}
func TestContextFormFile(t *testing.T) {

58
echo.go
View File

@ -49,12 +49,15 @@ import (
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"sync"
)
// Echo is the top-level framework instance.
// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics.
//
// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. This is very likely
// to happen when you access Echo instances through Context.Echo() method.
type Echo struct {
// premiddleware are middlewares that are run for every request before routing is done
premiddleware []MiddlewareFunc
@ -66,8 +69,8 @@ type Echo struct {
routerCreator func(e *Echo) Router
contextPool sync.Pool
// contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info for context
// creation time so we can allocate path parameter values slice.
// contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info at context
// creation moment so we can allocate path parameter values slice with correct size.
contextPathParamAllocSize int
// NewContextFunc allows using custom context implementations, instead of default *echo.context
@ -150,6 +153,8 @@ const (
PROPFIND = "PROPFIND"
// REPORT Method can be used to get information about a resource, see rfc 3253
REPORT = "REPORT"
// RouteNotFound is special method type for routes handling "route not found" (404) cases
RouteNotFound = "echo_route_not_found"
)
// Headers
@ -181,12 +186,14 @@ const (
HeaderXForwardedSsl = "X-Forwarded-Ssl"
HeaderXUrlScheme = "X-Url-Scheme"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
HeaderXRealIP = "X-Real-IP"
HeaderXRequestID = "X-Request-ID"
HeaderXCorrelationID = "X-Correlation-ID"
HeaderXRealIP = "X-Real-Ip"
HeaderXRequestID = "X-Request-Id"
HeaderXCorrelationID = "X-Correlation-Id"
HeaderXRequestedWith = "X-Requested-With"
HeaderServer = "Server"
HeaderOrigin = "Origin"
HeaderCacheControl = "Cache-Control"
HeaderConnection = "Connection"
// Access control
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
@ -403,6 +410,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
return e.Add(http.MethodTrace, path, h, m...)
}
// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases)
// for current request URL.
// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
// wildcard/match-any character (`/*`, `/download/*` etc).
//
// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(RouteNotFound, path, h, m...)
}
// Any registers a new route for all supported HTTP methods and path with matching handler
// in the router with optional route-level middleware. Panics on error.
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
@ -707,20 +724,26 @@ func newDefaultFS() *defaultFS {
dir, _ := os.Getwd()
return &defaultFS{
prefix: dir,
fs: os.DirFS(dir),
fs: nil,
}
}
func (fs defaultFS) Open(name string) (fs.File, error) {
if fs.fs == nil {
return os.Open(name)
}
return fs.fs.Open(name)
}
func subFS(currentFs fs.FS, root string) (fs.FS, error) {
root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
if dFS, ok := currentFs.(*defaultFS); ok {
// we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to
// allow cases when root is given as `../somepath` which is not valid for fs.FS
root = filepath.Join(dFS.prefix, root)
// we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
// fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
// were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
if isRelativePath(root) {
root = filepath.Join(dFS.prefix, root)
}
return &defaultFS{
prefix: root,
fs: os.DirFS(root),
@ -729,6 +752,21 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) {
return fs.Sub(currentFs, root)
}
func isRelativePath(path string) bool {
if path == "" {
return true
}
if path[0] == '/' {
return false
}
if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 {
// https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names
// https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats
return false
}
return true
}
// MustSubFS creates sub FS from current filesystem or panic on failure.
// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
//

View File

@ -336,16 +336,54 @@ func TestEchoStaticRedirectIndex(t *testing.T) {
}
func TestEchoFile(t *testing.T) {
e := New()
ri := e.File("/walle", "_fixture/images/walle.png")
assert.Equal(t, http.MethodGet, ri.Method())
assert.Equal(t, "/walle", ri.Path())
assert.Equal(t, "GET:/walle", ri.Name())
assert.Nil(t, ri.Params())
var testCases = []struct {
name string
givenPath string
givenFile string
whenPath string
expectCode int
expectStartsWith string
}{
{
name: "ok",
givenPath: "/walle",
givenFile: "_fixture/images/walle.png",
whenPath: "/walle",
expectCode: http.StatusOK,
expectStartsWith: string([]byte{0x89, 0x50, 0x4e}),
},
{
name: "ok with relative path",
givenPath: "/",
givenFile: "./go.mod",
whenPath: "/",
expectCode: http.StatusOK,
expectStartsWith: "module github.com/labstack/echo/v",
},
{
name: "nok file does not exist",
givenPath: "/",
givenFile: "./this-file-does-not-exist",
whenPath: "/",
expectCode: http.StatusNotFound,
expectStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
c, b := request(http.MethodGet, "/walle", e)
assert.Equal(t, http.StatusOK, c)
assert.NotEmpty(t, b)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New() // we are using echo.defaultFS instance
e.File(tc.givenPath, tc.givenFile)
c, b := request(http.MethodGet, tc.whenPath, e)
assert.Equal(t, tc.expectCode, c)
if len(b) > len(tc.expectStartsWith) {
b = b[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, b)
})
}
}
func TestEchoMiddleware(t *testing.T) {
@ -880,6 +918,70 @@ func TestEchoGroup(t *testing.T) {
assert.Equal(t, "023", buf.String())
}
func TestEcho_RouteNotFound(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectCode int
}{
{
name: "404, route to static not found handler /a/c/xx",
whenURL: "/a/c/xx",
expectRoute: "GET /a/c/xx",
expectCode: http.StatusNotFound,
},
{
name: "404, route to path param not found handler /a/:file",
whenURL: "/a/echo.exe",
expectRoute: "GET /a/:file",
expectCode: http.StatusNotFound,
},
{
name: "404, route to any not found handler /*",
whenURL: "/b/echo.exe",
expectRoute: "GET /*",
expectCode: http.StatusNotFound,
},
{
name: "200, route /a/c/df to /a/c/df",
whenURL: "/a/c/df",
expectRoute: "GET /a/c/df",
expectCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}
e.GET("/", okHandler)
e.GET("/a/c/df", okHandler)
e.GET("/a/b*", okHandler)
e.PUT("/*", okHandler)
e.RouteNotFound("/a/c/xx", notFoundHandler) // static
e.RouteNotFound("/a/:file", notFoundHandler) // param
e.RouteNotFound("/*", notFoundHandler) // any
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectRoute, rec.Body.String())
})
}
}
func TestEchoNotFound(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/files", nil)

View File

@ -157,6 +157,13 @@ func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
return g.Add(http.MethodGet, path, handler, middleware...)
}
// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
//
// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(RouteNotFound, path, h, m...)
}
// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
ri, err := g.AddRoute(Route{

View File

@ -303,6 +303,71 @@ func TestGroup_TRACE(t *testing.T) {
assert.Equal(t, `OK`, body)
}
func TestGroup_RouteNotFound(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectCode int
}{
{
name: "404, route to static not found handler /group/a/c/xx",
whenURL: "/group/a/c/xx",
expectRoute: "GET /group/a/c/xx",
expectCode: http.StatusNotFound,
},
{
name: "404, route to path param not found handler /group/a/:file",
whenURL: "/group/a/echo.exe",
expectRoute: "GET /group/a/:file",
expectCode: http.StatusNotFound,
},
{
name: "404, route to any not found handler /group/*",
whenURL: "/group/b/echo.exe",
expectRoute: "GET /group/*",
expectCode: http.StatusNotFound,
},
{
name: "200, route /group/a/c/df to /group/a/c/df",
whenURL: "/group/a/c/df",
expectRoute: "GET /group/a/c/df",
expectCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/group")
okHandler := func(c Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
notFoundHandler := func(c Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}
g.GET("/", okHandler)
g.GET("/a/c/df", okHandler)
g.GET("/a/b*", okHandler)
g.PUT("/*", okHandler)
g.RouteNotFound("/a/c/xx", notFoundHandler) // static
g.RouteNotFound("/a/:file", notFoundHandler) // param
g.RouteNotFound("/*", notFoundHandler) // any
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectRoute, rec.Body.String())
})
}
}
func TestGroup_Any(t *testing.T) {
e := New()

142
ip.go
View File

@ -6,6 +6,130 @@ import (
"strings"
)
/**
By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 )
Source: https://echo.labstack.com/guide/ip-address/
IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more.
Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that.
However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application.
In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally.
Otherwise, you might give someone a chance of deceiving you. **A security risk!**
To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure.
In Echo, this can be done by configuring `Echo#IPExtractor` appropriately.
This guides show you why and how.
> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice.
Let's start from two questions to know the right direction:
1. Do you put any HTTP (L7) proxy in front of the application?
- It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway).
2. If yes, what HTTP header do your proxies use to pass client IP to the application?
## Case 1. With no proxy
If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer.
Any HTTP header is untrustable because the clients have full control what headers to be set.
In this case, use `echo.ExtractIPDirect()`.
```go
e.IPExtractor = echo.ExtractIPDirect()
```
## Case 2. With proxies using `X-Forwarded-For` header
[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header
to relay clients' IP addresses.
At each hop on the proxies, they append the request IP address at the end of the header.
Following example diagram illustrates this behavior.
```text
"Origin" > Proxy 1 > Proxy 2 > Your app
(IP: a) (IP: b) (IP: c)
Case 1.
XFF: "" "a" "a, b"
~~~~~~
Case 2.
XFF: "x" "x, a" "x, a, b"
~~~~~~~~~
What your app will see
```
In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is
configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre".
In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`.
In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`.
```go
e.IPExtractor = echo.ExtractIPFromXFFHeader()
```
By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address
from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
[RFC4193](https://tools.ietf.org/html/rfc4193)).
To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
E.g.:
```go
e.IPExtractor = echo.ExtractIPFromXFFHeader(
TrustLinkLocal(false),
TrustIPRanges(lbIPRange),
)
```
- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
## Case 3. With proxies using `X-Real-IP` header
`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF.
If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`.
```go
e.IPExtractor = echo.ExtractIPFromRealIPHeader()
```
Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address
from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
[RFC4193](https://tools.ietf.org/html/rfc4193)).
To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**.
> Otherwise there is a chance of fraud, as it is what clients can control.
## About default behavior
In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer.
As you might already notice, after reading this article, this is not good.
Sole reason this is default is just backward compatibility.
## Private IP ranges
See: https://en.wikipedia.org/wiki/Private_network
Private IPv4 address ranges (RFC 1918):
* 10.0.0.0 10.255.255.255 (24-bit block)
* 172.16.0.0 172.31.255.255 (20-bit block)
* 192.168.0.0 192.168.255.255 (16-bit block)
Private IPv6 address ranges:
* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
*/
type ipChecker struct {
trustLoopback bool
trustLinkLocal bool
@ -52,6 +176,7 @@ func newIPChecker(configs []TrustOption) *ipChecker {
return checker
}
// Go1.16+ added `ip.IsPrivate()` but until that use this implementation
func isPrivateIPRange(ip net.IP) bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[0] == 10 ||
@ -87,10 +212,12 @@ type IPExtractor func(*http.Request) string
// ExtractIPDirect extracts IP address using actual IP address.
// Use this if your server faces to internet directory (i.e.: uses no proxy).
func ExtractIPDirect() IPExtractor {
return func(req *http.Request) string {
ra, _, _ := net.SplitHostPort(req.RemoteAddr)
return ra
}
return extractIP
}
func extractIP(req *http.Request) string {
ra, _, _ := net.SplitHostPort(req.RemoteAddr)
return ra
}
// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header.
@ -98,14 +225,13 @@ func ExtractIPDirect() IPExtractor {
func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
checker := newIPChecker(options)
return func(req *http.Request) string {
directIP := ExtractIPDirect()(req)
realIP := req.Header.Get(HeaderXRealIP)
if realIP != "" {
if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) {
if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) {
return realIP
}
}
return directIP
return extractIP(req)
}
}
@ -115,7 +241,7 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
checker := newIPChecker(options)
return func(req *http.Request) string {
directIP := ExtractIPDirect()(req)
directIP := extractIP(req)
xffs := req.Header[HeaderXForwardedFor]
if len(xffs) == 0 {
return directIP

View File

@ -1,235 +1,606 @@
package echo
import (
"github.com/stretchr/testify/assert"
"net"
"net/http"
"strings"
"testing"
testify "github.com/stretchr/testify/assert"
)
const (
// For RemoteAddr
ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8
sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080"
ipForRemoteAddrExternal = "203.0.113.1"
sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080"
// For x-real-ip
ipForRealIP = "203.0.113.10"
// For XFF
ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16
ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16
ipForXFF3External = "2001:db8::103"
ipForXFF4Private = "fc00::104" // From fc00::/7
ipForXFF5External = "198.51.100.105"
ipForXFF6External = "192.0.2.106"
ipForXFFBroken = "this.is.broken.lol"
// keys for test cases
ipTestReqKeyNoHeader = "no header"
ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external"
ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal"
ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external"
ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal"
ipTestReqKeyXFFExternal = "xff; remote addr external"
ipTestReqKeyXFFInternal = "xff; remote addr internal"
ipTestReqKeyBrokenXFF = "broken xff"
)
var (
sampleXFF = strings.Join([]string{
ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal,
}, ", ")
requests = map[string]*http.Request{
ipTestReqKeyNoHeader: &http.Request{
RemoteAddr: sampleRemoteAddrExternal,
},
ipTestReqKeyRealIPExternal: &http.Request{
Header: http.Header{
"X-Real-Ip": []string{ipForRealIP},
},
RemoteAddr: sampleRemoteAddrExternal,
},
ipTestReqKeyRealIPInternal: &http.Request{
Header: http.Header{
"X-Real-Ip": []string{ipForRealIP},
},
RemoteAddr: sampleRemoteAddrLoopback,
},
ipTestReqKeyRealIPAndXFFExternal: &http.Request{
Header: http.Header{
"X-Real-Ip": []string{ipForRealIP},
HeaderXForwardedFor: []string{sampleXFF},
},
RemoteAddr: sampleRemoteAddrExternal,
},
ipTestReqKeyRealIPAndXFFInternal: &http.Request{
Header: http.Header{
"X-Real-Ip": []string{ipForRealIP},
HeaderXForwardedFor: []string{sampleXFF},
},
RemoteAddr: sampleRemoteAddrLoopback,
},
ipTestReqKeyXFFExternal: &http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{sampleXFF},
},
RemoteAddr: sampleRemoteAddrExternal,
},
ipTestReqKeyXFFInternal: &http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{sampleXFF},
},
RemoteAddr: sampleRemoteAddrLoopback,
},
ipTestReqKeyBrokenXFF: &http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal},
},
RemoteAddr: sampleRemoteAddrLoopback,
},
func mustParseCIDR(s string) *net.IPNet {
_, IPNet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
)
return IPNet
}
func TestExtractIP(t *testing.T) {
_, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0")
_, ipv6AllRange, _ := net.ParseCIDR("::/0")
_, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48")
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24")
tests := map[string]*struct {
extractor IPExtractor
expectedIPs map[string]string
func TestIPChecker_TrustOption(t *testing.T) {
var testCases = []struct {
name string
givenOptions []TrustOption
whenIP string
expect bool
}{
"ExtractIPDirect": {
ExtractIPDirect(),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
{
name: "ip is within trust range, trusts additional private IPV6 network",
givenOptions: []TrustOption{
TrustLoopback(false),
TrustLinkLocal(false),
TrustPrivateNet(false),
// this is private IPv6 ip
// CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
// Address: 2001:0db8:0000:0000:0000:0000:0000:0103
// Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
// Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
TrustIPRange(mustParseCIDR("2001:db8::103/48")),
},
whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
expect: true,
},
"ExtractIPFromRealIPHeader(default)": {
ExtractIPFromRealIPHeader(),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRealIP,
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPAndXFFInternal: ipForRealIP,
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromRealIPHeader(trust only direct-facing proxy)": {
ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRealIP,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForRealIP,
ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromRealIPHeader(trust direct-facing proxy)": {
ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRealIP,
ipTestReqKeyRealIPInternal: ipForRealIP,
ipTestReqKeyRealIPAndXFFExternal: ipForRealIP,
ipTestReqKeyRealIPAndXFFInternal: ipForRealIP,
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromXFFHeader(default)": {
ExtractIPFromXFFHeader(),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External,
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyXFFInternal: ipForXFF3External,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromXFFHeader(trust only direct-facing proxy)": {
ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal,
ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyXFFExternal: ipForXFF1LinkLocal,
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromXFFHeader(trust direct-facing proxy)": {
ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External,
ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External,
ipTestReqKeyXFFExternal: ipForXFF3External,
ipTestReqKeyXFFInternal: ipForXFF3External,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromXFFHeader(trust everything)": {
// This is similar to legacy behavior, but ignores x-real-ip header.
ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External,
ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External,
ipTestReqKeyXFFExternal: ipForXFF6External,
ipTestReqKeyXFFInternal: ipForXFF6External,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
},
},
"ExtractIPFromXFFHeader(trust ipForXFF3External)": {
// This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't
ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)),
map[string]string{
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External,
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
ipTestReqKeyXFFInternal: ipForXFF5External,
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
{
name: "ip is within trust range, trusts additional private IPV6 network",
givenOptions: []TrustOption{
TrustIPRange(mustParseCIDR("2001:db8::103/48")),
},
whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
expect: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
assert := testify.New(t)
for key, req := range requests {
actual := test.extractor(req)
expected := test.expectedIPs[key]
assert.Equal(expected, actual, "Request: %s", key)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
checker := newIPChecker(tc.givenOptions)
result := checker.trust(net.ParseIP(tc.whenIP))
assert.Equal(t, tc.expect, result)
})
}
}
func TestTrustIPRange(t *testing.T) {
var testCases = []struct {
name string
givenRange string
whenIP string
expect bool
}{
{
name: "ip is within trust range, IPV6 network range",
// CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
// Address: 2001:0db8:0000:0000:0000:0000:0000:0103
// Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
// Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
givenRange: "2001:db8::103/48",
whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
expect: true,
},
{
name: "ip is outside (upper bounds) of trust range, IPV6 network range",
givenRange: "2001:db8::103/48",
whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000",
expect: false,
},
{
name: "ip is outside (lower bounds) of trust range, IPV6 network range",
givenRange: "2001:db8::103/48",
whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff",
expect: false,
},
{
name: "ip is within trust range, IPV4 network range",
// CIDR Notation: 8.8.8.8/24
// Address: 8.8.8.8
// Range start: 8.8.8.0
// Range end: 8.8.8.255
givenRange: "8.8.8.0/24",
whenIP: "8.8.8.8",
expect: true,
},
{
name: "ip is within trust range, IPV4 network range",
// CIDR Notation: 8.8.8.8/24
// Address: 8.8.8.8
// Range start: 8.8.8.0
// Range end: 8.8.8.255
givenRange: "8.8.8.0/24",
whenIP: "8.8.8.8",
expect: true,
},
{
name: "ip is outside (upper bounds) of trust range, IPV4 network range",
givenRange: "8.8.8.0/24",
whenIP: "8.8.9.0",
expect: false,
},
{
name: "ip is outside (lower bounds) of trust range, IPV4 network range",
givenRange: "8.8.8.0/24",
whenIP: "8.8.7.255",
expect: false,
},
{
name: "public ip, trust everything in IPV4 network range",
givenRange: "0.0.0.0/0",
whenIP: "8.8.8.8",
expect: true,
},
{
name: "internal ip, trust everything in IPV4 network range",
givenRange: "0.0.0.0/0",
whenIP: "127.0.10.1",
expect: true,
},
{
name: "public ip, trust everything in IPV6 network range",
givenRange: "::/0",
whenIP: "2a00:1450:4026:805::200e",
expect: true,
},
{
name: "internal ip, trust everything in IPV6 network range",
givenRange: "::/0",
whenIP: "0:0:0:0:0:0:0:1",
expect: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cidr := mustParseCIDR(tc.givenRange)
checker := newIPChecker([]TrustOption{
TrustLoopback(false), // disable to avoid interference
TrustLinkLocal(false), // disable to avoid interference
TrustPrivateNet(false), // disable to avoid interference
TrustIPRange(cidr),
})
result := checker.trust(net.ParseIP(tc.whenIP))
assert.Equal(t, tc.expect, result)
})
}
}
func TestTrustPrivateNet(t *testing.T) {
var testCases = []struct {
name string
whenIP string
expect bool
}{
{
name: "do not trust public IPv4 address",
whenIP: "8.8.8.8",
expect: false,
},
{
name: "do not trust public IPv6 address",
whenIP: "2a00:1450:4026:805::200e",
expect: false,
},
{ // Class A: 10.0.0.0 — 10.255.255.255
name: "do not trust IPv4 just outside of class A (lower bounds)",
whenIP: "9.255.255.255",
expect: false,
},
{
name: "do not trust IPv4 just outside of class A (upper bounds)",
whenIP: "11.0.0.0",
expect: false,
},
{
name: "trust IPv4 of class A (lower bounds)",
whenIP: "10.0.0.0",
expect: true,
},
{
name: "trust IPv4 of class A (upper bounds)",
whenIP: "10.255.255.255",
expect: true,
},
{ // Class B: 172.16.0.0 — 172.31.255.255
name: "do not trust IPv4 just outside of class B (lower bounds)",
whenIP: "172.15.255.255",
expect: false,
},
{
name: "do not trust IPv4 just outside of class B (upper bounds)",
whenIP: "172.32.0.0",
expect: false,
},
{
name: "trust IPv4 of class B (lower bounds)",
whenIP: "172.16.0.0",
expect: true,
},
{
name: "trust IPv4 of class B (upper bounds)",
whenIP: "172.31.255.255",
expect: true,
},
{ // Class C: 192.168.0.0 — 192.168.255.255
name: "do not trust IPv4 just outside of class C (lower bounds)",
whenIP: "192.167.255.255",
expect: false,
},
{
name: "do not trust IPv4 just outside of class C (upper bounds)",
whenIP: "192.169.0.0",
expect: false,
},
{
name: "trust IPv4 of class C (lower bounds)",
whenIP: "192.168.0.0",
expect: true,
},
{
name: "trust IPv4 of class C (upper bounds)",
whenIP: "192.168.255.255",
expect: true,
},
{ // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
// splits the address block in two equally sized halves, fc00::/8 and fd00::/8.
// https://en.wikipedia.org/wiki/Unique_local_address
name: "trust IPv6 private address",
whenIP: "fdfc:3514:2cb3:4bd5::",
expect: true,
},
{
name: "do not trust IPv6 just out of /fd (upper bounds)",
whenIP: "/fe00:0000:0000:0000:0000",
expect: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
checker := newIPChecker([]TrustOption{
TrustLoopback(false), // disable to avoid interference
TrustLinkLocal(false), // disable to avoid interference
TrustPrivateNet(true),
})
result := checker.trust(net.ParseIP(tc.whenIP))
assert.Equal(t, tc.expect, result)
})
}
}
func TestTrustLinkLocal(t *testing.T) {
var testCases = []struct {
name string
whenIP string
expect bool
}{
{
name: "trust link local IPv4 address (lower bounds)",
whenIP: "169.254.0.0",
expect: true,
},
{
name: "trust link local IPv4 address (upper bounds)",
whenIP: "169.254.255.255",
expect: true,
},
{
name: "do not trust link local IPv4 address (outside of lower bounds)",
whenIP: "169.253.255.255",
expect: false,
},
{
name: "do not trust link local IPv4 address (outside of upper bounds)",
whenIP: "169.255.0.0",
expect: false,
},
{
name: "trust link local IPv6 address ",
whenIP: "fe80::1",
expect: true,
},
{
name: "do not trust link local IPv6 address ",
whenIP: "fec0::1",
expect: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
checker := newIPChecker([]TrustOption{
TrustLoopback(false), // disable to avoid interference
TrustPrivateNet(false), // disable to avoid interference
TrustLinkLocal(true),
})
result := checker.trust(net.ParseIP(tc.whenIP))
assert.Equal(t, tc.expect, result)
})
}
}
func TestTrustLoopback(t *testing.T) {
var testCases = []struct {
name string
whenIP string
expect bool
}{
{
name: "trust IPv4 as localhost",
whenIP: "127.0.0.1",
expect: true,
},
{
name: "trust IPv6 as localhost",
whenIP: "::1",
expect: true,
},
{
name: "do not trust public ip as localhost",
whenIP: "8.8.8.8",
expect: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
checker := newIPChecker([]TrustOption{
TrustLinkLocal(false), // disable to avoid interference
TrustPrivateNet(false), // disable to avoid interference
TrustLoopback(true),
})
result := checker.trust(net.ParseIP(tc.whenIP))
assert.Equal(t, tc.expect, result)
})
}
}
func TestExtractIPDirect(t *testing.T) {
var testCases = []struct {
name string
whenRequest http.Request
expectIP string
}{
{
name: "request has no headers, extracts IP from request remote addr",
whenRequest: http.Request{
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.10"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.10"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.10"},
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"127.0.0.1"},
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
extractedIP := ExtractIPDirect()(&tc.whenRequest)
assert.Equal(t, tc.expectIP, extractedIP)
})
}
}
func TestExtractIPFromRealIPHeader(t *testing.T) {
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
var testCases = []struct {
name string
givenTrustOptions []TrustOption
whenRequest http.Request
expectIP string
}{
{
name: "request has no headers, extracts IP from request remote addr",
whenRequest: http.Request{
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.199"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.199",
},
{
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.199"},
HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.199",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest)
assert.Equal(t, tc.expectIP, extractedIP)
})
}
}
func TestExtractIPFromXFFHeader(t *testing.T) {
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
var testCases = []struct {
name string
givenTrustOptions []TrustOption
whenRequest http.Request
expectIP string
}{
{
name: "request has no headers, extracts IP from request remote addr",
whenRequest: http.Request{
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request has INVALID external XFF header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.3",
},
{
name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
givenTrustOptions: []TrustOption{
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
},
// from request its seems that request has been proxied through 6 servers.
// 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed)
// 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs)
// 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office)
// 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
// 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"},
},
RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP
},
expectIP: "203.0.100.100", // this is first trusted IP in XFF chain
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest)
assert.Equal(t, tc.expectIP, extractedIP)
})
}
}

View File

@ -4,7 +4,7 @@ import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
@ -72,9 +72,11 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
continue
}
// Invalid base64 shouldn't be treated as error
// instead should be treated as invalid client input
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode)
continue
}
idx := bytes.IndexByte(b, ':')

View File

@ -56,7 +56,7 @@ func TestBasicAuth(t *testing.T) {
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",

View File

@ -43,6 +43,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
@ -55,6 +56,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()

View File

@ -144,7 +144,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastTokenErr error
outer:
for _, extractor := range extractors {
clientTokens, err := extractor(c)
clientTokens, _, err := extractor(c)
if err != nil {
lastExtractorErr = err
continue

View File

@ -79,9 +79,6 @@ func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
@ -91,10 +88,10 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)

View File

@ -13,6 +13,24 @@ const (
extractorLimit = 20
)
// ExtractorSource is type to indicate source for extracted value
type ExtractorSource string
const (
// ExtractorSourceHeader means value was extracted from request header
ExtractorSourceHeader ExtractorSource = "header"
// ExtractorSourceQuery means value was extracted from request query parameters
ExtractorSourceQuery ExtractorSource = "query"
// ExtractorSourcePathParam means value was extracted from route path parameters
ExtractorSourcePathParam ExtractorSource = "param"
// ExtractorSourceCookie means value was extracted from request cookies
ExtractorSourceCookie ExtractorSource = "cookie"
// ExtractorSourceForm means value was extracted from request form values
ExtractorSourceForm ExtractorSource = "form"
// ExtractorSourceCustom means value was extracted by custom extractor
ExtractorSourceCustom ExtractorSource = "custom"
)
// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
type ValueExtractorError struct {
message string
@ -31,7 +49,7 @@ var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing valu
var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
type ValuesExtractor func(c echo.Context) ([]string, error)
type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error)
func createExtractors(lookups string) ([]ValuesExtractor, error) {
if lookups == "" {
@ -75,10 +93,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
prefixLen := len(valuePrefix)
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
header = textproto.CanonicalMIMEHeaderKey(header)
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
values := c.Request().Header.Values(header)
if len(values) == 0 {
return nil, errHeaderExtractorValueMissing
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
}
result := make([]string, 0)
@ -100,30 +118,30 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
if len(result) == 0 {
if prefixLen > 0 {
return nil, errHeaderExtractorValueInvalid
return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
}
return nil, errHeaderExtractorValueMissing
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
}
return result, nil
return result, ExtractorSourceHeader, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, errQueryExtractorValueMissing
return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
} else if len(result) > extractorLimit-1 {
result = result[:extractorLimit]
}
return result, nil
return result, ExtractorSourceQuery, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
result := make([]string, 0)
for i, p := range c.PathParams() {
if param == p.Name {
@ -134,18 +152,18 @@ func valuesFromParam(param string) ValuesExtractor {
}
}
if len(result) == 0 {
return nil, errParamExtractorValueMissing
return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
}
return result, nil
return result, ExtractorSourcePathParam, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
return func(c echo.Context) ([]string, ExtractorSource, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, errCookieExtractorValueMissing
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
}
result := make([]string, 0)
@ -158,26 +176,26 @@ func valuesFromCookie(name string) ValuesExtractor {
}
}
if len(result) == 0 {
return nil, errCookieExtractorValueMissing
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
}
return result, nil
return result, ExtractorSourceCookie, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
if parseErr := c.Request().ParseForm(); parseErr != nil {
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
return func(c echo.Context) ([]string, ExtractorSource, error) {
if c.Request().Form == nil {
_ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, errFormExtractorValueMissing
return nil, ExtractorSourceForm, errFormExtractorValueMissing
}
if len(values) > extractorLimit-1 {
values = values[:extractorLimit]
}
result := append([]string{}, values...)
return result, nil
return result, ExtractorSourceForm, nil
}
}

View File

@ -1,9 +1,11 @@
package middleware
import (
"bytes"
"fmt"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
@ -18,6 +20,7 @@ func TestCreateExtractors(t *testing.T) {
givenPathParams echo.PathParams
whenLoopups string
expectValues []string
expectSource ExtractorSource
expectCreateError string
expectError string
}{
@ -30,6 +33,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
expectSource: ExtractorSourceHeader,
},
{
name: "ok, form",
@ -43,6 +47,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
expectSource: ExtractorSourceForm,
},
{
name: "ok, cookie",
@ -53,6 +58,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
expectSource: ExtractorSourceCookie,
},
{
name: "ok, param",
@ -61,6 +67,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "param:id",
expectValues: []string{"123"},
expectSource: ExtractorSourcePathParam,
},
{
name: "ok, query",
@ -70,6 +77,7 @@ func TestCreateExtractors(t *testing.T) {
},
whenLoopups: "query:id",
expectValues: []string{"999"},
expectSource: ExtractorSourceQuery,
},
{
name: "nok, invalid lookup",
@ -100,8 +108,9 @@ func TestCreateExtractors(t *testing.T) {
assert.NoError(t, err)
for _, e := range extractors {
values, eErr := e(c)
values, source, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, tc.expectSource, source)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
@ -226,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) {
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceHeader, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -287,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) {
extractor := valuesFromQuery(tc.whenName)
values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceQuery, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -366,8 +377,9 @@ func TestValuesFromParam(t *testing.T) {
extractor := valuesFromParam(tc.whenName)
values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourcePathParam, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -446,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) {
extractor := valuesFromCookie(tc.whenName)
values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceCookie, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -483,6 +496,25 @@ func TestValuesFromForm(t *testing.T) {
return req
}
exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request {
var b bytes.Buffer
w := multipart.NewWriter(&b)
w.WriteField("name", "Jon Snow")
w.WriteField("emails[]", "jon@labstack.com")
if mod != nil {
mod(w)
}
fw, _ := w.CreateFormFile("upload", "my.file")
fw.Write([]byte(`<div>hi</div>`))
w.Close()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String()))
req.Header.Add(echo.HeaderContentType, w.FormDataContentType())
return req
}
var testCases = []struct {
name string
givenRequest *http.Request
@ -504,6 +536,14 @@ func TestValuesFromForm(t *testing.T) {
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, POST multipart/form, multiple value",
givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) {
w.WriteField("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
@ -524,16 +564,6 @@ func TestValuesFromForm(t *testing.T) {
whenName: "nope",
expectError: errFormExtractorValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
{
name: "ok, cut values over extractorLimit",
givenRequest: examplePostFormRequest(func(v *url.Values) {
@ -559,8 +589,9 @@ func TestValuesFromForm(t *testing.T) {
extractor := valuesFromForm(tc.whenName)
values, err := extractor(c)
values, source, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ExtractorSourceForm, source)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {

View File

@ -64,7 +64,7 @@ type JWTConfig struct {
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
ParseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)
}
// JWTSuccessHandler defines a function which is executed for a valid token.
@ -101,7 +101,7 @@ var DefaultJWTConfig = JWTConfig{
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
func JWT(parseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)) echo.MiddlewareFunc {
c := DefaultJWTConfig
c.ParseTokenFunc = parseTokenFunc
return JWTWithConfig(c)
@ -152,13 +152,13 @@ func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastExtractorErr error
var lastTokenErr error
for _, extractor := range extractors {
auths, extrErr := extractor(c)
auths, source, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
for _, auth := range auths {
token, err := config.ParseTokenFunc(c, auth)
token, err := config.ParseTokenFunc(c, auth, source)
if err != nil {
lastTokenErr = err
continue

View File

@ -21,7 +21,7 @@ import (
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKey is not provided
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
// keyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
@ -41,7 +41,7 @@ func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]in
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return func(c echo.Context, auth string) (interface{}, error) {
return func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
if err != nil {
return nil, err

View File

@ -14,7 +14,7 @@ import (
"github.com/stretchr/testify/assert"
)
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != signingMethod {
@ -23,7 +23,7 @@ func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface
return signingKey, nil
}
return func(c echo.Context, auth string) (interface{}, error) {
return func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
if err != nil {
return nil, err
@ -405,7 +405,6 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
},
@ -424,7 +423,7 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
config := tc.given
parseTokenCalled := false
config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
config.ParseTokenFunc = func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
parseTokenCalled = true
return nil, errors.New("parsing failed")
}
@ -453,8 +452,10 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
// with current JWT middleware
signingKey := []byte("secret")
var fromSource ExtractorSource
config := JWTConfig{
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
fromSource = source
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != "HS256" {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
@ -481,6 +482,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, fromSource, ExtractorSourceHeader)
assert.Equal(t, http.StatusTeapot, res.Code)
}
@ -494,7 +496,7 @@ func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
})
mw, err := JWTConfig{
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
return auth, nil
},
SuccessHandler: func(c echo.Context) {
@ -616,7 +618,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) {
mw, err := JWTConfig{
ContinueOnIgnoredError: tc.givenContinueOnIgnoredError,
TokenLookup: tc.givenTokenLookup,
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
return tc.whenParseReturn, tc.whenParseError
},
ErrorHandler: tc.givenErrorHandler,
@ -648,8 +650,8 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e.Use(JWTWithConfig(JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
TokenLookupFuncs: []ValuesExtractor{
func(c echo.Context) ([]string, error) {
return []string{c.Request().Header.Get("X-API-Key")}, nil
func(c echo.Context) ([]string, ExtractorSource, error) {
return []string{c.Request().Header.Get("X-API-Key")}, ExtractorSourceCustom, nil
},
},
}))

View File

@ -50,7 +50,7 @@ type KeyAuthConfig struct {
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
type KeyAuthValidator func(c echo.Context, key string) (bool, error)
type KeyAuthValidator func(c echo.Context, key string, source ExtractorSource) (bool, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
type KeyAuthErrorHandler func(c echo.Context, err error) error
@ -116,13 +116,13 @@ func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
var lastExtractorErr error
var lastValidatorErr error
for _, extractor := range extractors {
keys, extrErr := extractor(c)
keys, source, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
for _, key := range keys {
valid, err := config.Validator(c, key)
valid, err := config.Validator(c, key, source)
if err != nil {
lastValidatorErr = err
continue

View File

@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/assert"
)
func testKeyValidator(c echo.Context, key string) (bool, error) {
func testKeyValidator(c echo.Context, key string, source ExtractorSource) (bool, error) {
switch key {
case "valid-key":
return true, nil
@ -218,6 +218,25 @@ func TestKeyAuthWithConfig(t *testing.T) {
expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized, internal=some user defined error",
},
{
name: "ok, custom validator checks source",
givenRequest: func(req *http.Request) {
q := req.URL.Query()
q.Add("key", "valid-key")
req.URL.RawQuery = q.Encode()
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key"
conf.Validator = func(c echo.Context, key string, source ExtractorSource) (bool, error) {
if source == ExtractorSourceQuery {
return true, nil
}
return false, errors.New("invalid source")
}
},
expectHandlerCalled: true,
},
}
for _, tc := range testCases {
@ -267,7 +286,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) {
{
name: "ok, no error",
whenConfig: KeyAuthConfig{
Validator: func(c echo.Context, key string) (bool, error) {
Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) {
return false, nil
},
},
@ -283,7 +302,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) {
name: "ok, extractor source can not be split",
whenConfig: KeyAuthConfig{
KeyLookup: "nope",
Validator: func(c echo.Context, key string) (bool, error) {
Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) {
return false, nil
},
},
@ -293,7 +312,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) {
name: "ok, no extractors",
whenConfig: KeyAuthConfig{
KeyLookup: "nope:nope",
Validator: func(c echo.Context, key string) (bool, error) {
Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) {
return false, nil
},
},

View File

@ -22,6 +22,8 @@ type LoggerConfig struct {
// Tags to construct the logger format.
//
// - time_unix
// - time_unix_milli
// - time_unix_micro
// - time_unix_nano
// - time_rfc3339
// - time_rfc3339_nano
@ -119,6 +121,10 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
switch tag {
case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
case "time_unix_milli":
return buf.WriteString(strconv.FormatInt(time.Now().UnixMilli(), 10))
case "time_unix_micro":
return buf.WriteString(strconv.FormatInt(time.Now().UnixMicro(), 10))
case "time_unix_nano":
return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10))
case "time_rfc3339":

View File

@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"time"
@ -172,6 +173,52 @@ func TestLoggerCustomTimestamp(t *testing.T) {
assert.Error(t, err)
}
func TestLoggerTemplateWithTimeUnixMilli(t *testing.T) {
buf := new(bytes.Buffer)
e := echo.New()
e.Use(LoggerWithConfig(LoggerConfig{
Format: `${time_unix_milli}`,
Output: buf,
}))
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
unixMillis, err := strconv.ParseInt(buf.String(), 10, 64)
assert.NoError(t, err)
assert.WithinDuration(t, time.Unix(unixMillis/1000, 0), time.Now(), 3*time.Second)
}
func TestLoggerTemplateWithTimeUnixMicro(t *testing.T) {
buf := new(bytes.Buffer)
e := echo.New()
e.Use(LoggerWithConfig(LoggerConfig{
Format: `${time_unix_micro}`,
Output: buf,
}))
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
unixMicros, err := strconv.ParseInt(buf.String(), 10, 64)
assert.NoError(t, err)
assert.WithinDuration(t, time.Unix(unixMicros/1000000, 0), time.Now(), 3*time.Second)
}
func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) {
e := echo.New()

View File

@ -32,7 +32,6 @@ func TestMethodOverride(t *testing.T) {
func TestMethodOverride_formParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
@ -53,7 +52,6 @@ func TestMethodOverride_formParam(t *testing.T) {
func TestMethodOverride_queryParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}

View File

@ -341,10 +341,9 @@ func TestProxyError(t *testing.T) {
e := echo.New()
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
// Remote unreachable
rec = httptest.NewRecorder()
rec := httptest.NewRecorder()
req.URL.Path = "/api/users"
e.ServeHTTP(rec, req)
assert.Equal(t, "/api/users", req.URL.Path)

View File

@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"net/http"
"runtime"
"github.com/labstack/echo/v5"
@ -63,6 +64,9 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
defer func() {
if r := recover(); r != nil {
if r == http.ErrAbortHandler {
panic(r)
}
tmpErr, ok := r.(error)
if !ok {
tmpErr = fmt.Errorf("%v", r)

View File

@ -51,6 +51,33 @@ func TestRecover_skipper(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}
func TestRecoverErrAbortHandler(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Recover()(func(c echo.Context) error {
panic(http.ErrAbortHandler)
})
defer func() {
r := recover()
if r == nil {
assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`")
} else {
if err, ok := r.(error); ok {
assert.ErrorIs(t, err, http.ErrAbortHandler)
} else {
assert.Fail(t, "not of error type")
}
}
}()
hErr := h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.NotContains(t, hErr.Error(), "PANIC RECOVER")
}
func TestRecoverWithConfig(t *testing.T) {
var testCases = []struct {
name string

113
router.go
View File

@ -8,6 +8,15 @@ import (
)
// Router is interface for routing request contexts to registered routes.
//
// Contract between Echo/Context instance and the router:
// * all routes must be added through methods on echo.Echo instance.
// Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`).
// * Router must populate Context during Router.Route call with:
// * RoutableContext.SetPath
// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns)
// * RoutableContext.SetRouteInfo
// And optionally can set additional information to Context with RoutableContext.Set
type Router interface {
// Add registers Routable with the Router and returns registered RouteInfo
Add(routable Routable) (RouteInfo, error)
@ -19,11 +28,6 @@ type Router interface {
// Route searches Router for matching route and applies it to the given context. In case when no matching method
// was not found (405) or no matching route exists for path (404), router will return its implementation of 405/404
// handler function.
// When implementing custom Router make sure to populate Context during Router.Route call with:
// * RoutableContext.SetPath
// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns)
// * RoutableContext.SetRouteInfo
// And optionally can set additional information to Context with RoutableContext.Set
Route(c RoutableContext) HandlerFunc
}
@ -210,18 +214,22 @@ type routeMethod struct {
}
type routeMethods struct {
connect *routeMethod
delete *routeMethod
get *routeMethod
head *routeMethod
options *routeMethod
patch *routeMethod
post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
anyOther map[string]*routeMethod
connect *routeMethod
delete *routeMethod
get *routeMethod
head *routeMethod
options *routeMethod
patch *routeMethod
post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
anyOther map[string]*routeMethod
// notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases
notFoundHandler *routeMethod
allowHeader string
}
@ -249,6 +257,9 @@ func (m *routeMethods) set(method string, r *routeMethod) {
m.trace = r
case REPORT:
m.report = r
case RouteNotFound:
m.notFoundHandler = r
return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed
default:
if m.anyOther == nil {
m.anyOther = make(map[string]*routeMethod)
@ -353,6 +364,7 @@ func (m *routeMethods) isHandler() bool {
m.trace != nil ||
m.report != nil ||
len(m.anyOther) != 0
// RouteNotFound/404 is not considered as a handler
}
// Routes returns all registered routes
@ -611,7 +623,12 @@ func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMetho
}
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else if lcpLen < prefixLen {
// Split node
// Split node into two before we insert new node.
// This happens when we are inserting path that is submatch of any existing inserted paths.
// For example, we have node `/test` and now are about to insert `/te/*`. In that case
// 1. overlapping part is `/te` that is used as parent node
// 2. `st` is part from existing node that is not matching - it gets its own node (child to `/te`)
// 3. `/*` is the new part we are about to insert (child to `/te`)
n := newNode(
currentNode.kind,
currentNode.prefix[lcpLen:],
@ -735,10 +752,8 @@ func (n *node) findStaticChild(l byte) *node {
}
func (n *node) findChildWithLabel(l byte) *node {
for _, c := range n.staticChildren {
if c.label == l {
return c
}
if c := n.findStaticChild(l); c != nil {
return c
}
if l == paramLabel {
return n.paramChild
@ -751,11 +766,7 @@ func (n *node) findChildWithLabel(l byte) *node {
func (n *node) setHandler(method string, r *routeMethod) {
n.methods.set(method, r)
if r != nil && r.handler != nil {
n.isHandler = true
} else {
n.isHandler = n.methods.isHandler()
}
n.isHandler = n.methods.isHandler()
}
// Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to Context
@ -900,7 +911,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
// No matching prefix, let's backtrack to the first possible alternative node of the decision path
nk, ok := backtrackToNextNodeKind(staticKind)
if !ok {
break // No other possibilities on the decision path
break // No other possibilities on the decision path, handler will be whatever context is reset to.
} else if nk == paramKind {
goto Param
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
@ -916,15 +927,21 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
search = search[lcpLen:]
searchIndex = searchIndex + lcpLen
// Finish routing if no remaining search and we are on a node with handler and matching method type
if search == "" && currentNode.isHandler {
// check if current node has handler registered for http method we are looking for. we store currentNode as
// best matching in case we do no find no more routes matching this path+method
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedRouteMethod = rMethod
// Finish routing if is no request path remaining to search
if search == "" {
// in case of node that is handler we have exact method type match or something for 405 to use
if currentNode.isHandler {
// check if current node has handler registered for http method we are looking for. we store currentNode as
// best matching in case we do no find no more routes matching this path+method
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.methods.find(req.Method); h != nil {
matchedRouteMethod = h
break
}
} else if currentNode.methods.notFoundHandler != nil {
matchedRouteMethod = currentNode.methods.notFoundHandler
break
}
}
@ -944,7 +961,8 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
i := 0
l := len(search)
if currentNode.isLeaf {
// when param node does not have any children then param node should act similarly to any node - consider all remaining search as match
// when param node does not have any children (path param is last piece of route path) then param node should
// act similarly to any node - consider all remaining search as match
i = l
} else {
for ; i < l && search[i] != '/'; i++ {
@ -969,13 +987,16 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
searchIndex += +len(search)
search = ""
// check if current node has handler registered for http method we are looking for. we store currentNode as
// best matching in case we do no find no more routes matching this path+method
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedRouteMethod = rMethod
break
}
// we store currentNode as best matching in case we do not find more routes matching this path+method. Needed for 405
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedRouteMethod = rMethod
if currentNode.methods.notFoundHandler != nil {
matchedRouteMethod = currentNode.methods.notFoundHandler
break
}
}
@ -1017,7 +1038,13 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
rPath = currentNode.originalPath
rInfo = notFoundRouteInfo
if currentNode.isHandler {
if currentNode.methods.notFoundHandler != nil {
matchedRouteMethod = currentNode.methods.notFoundHandler
rInfo = matchedRouteMethod.routeInfo
rPath = matchedRouteMethod.path
rHandler = matchedRouteMethod.handler
} else if currentNode.isHandler {
rInfo = methodNotAllowedRouteInfo
c.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader)

View File

@ -1,7 +1,6 @@
package echo
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
@ -2243,6 +2242,64 @@ func TestRouter_Match_DifferentParamNamesForSamePlace(t *testing.T) {
}
}
// Issue #2164 - this test is meant to document path parameter behaviour when request url has empty value in place
// of the path parameter. As tests show the result is different depending on where parameter exists in the route path.
func TestDefaultRouter_PathParamsCanMatchEmptyValues(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute string
expectParam map[string]string
expectError error
}{
{
name: "ok, route is matched with even empty param is in the middle and between slashes",
whenURL: "/a//b",
expectRoute: "/a/:id/b",
expectParam: map[string]string{"id": ""},
},
{
name: "ok, route is matched with even empty param is in the middle",
whenURL: "/a2/b",
expectRoute: "/a2:id/b",
expectParam: map[string]string{"id": ""},
},
{
name: "ok, route is NOT matched with even empty param is at the end",
whenURL: "/a3/",
expectRoute: "",
expectParam: map[string]string{},
expectError: ErrNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/a/:id/b", handlerFunc)
e.GET("/a2:id/b", handlerFunc)
e.GET("/a3/:id", handlerFunc)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
c := e.NewContext(req, nil).(*DefaultContext)
handler := e.router.Route(c)
err := handler(c)
if tc.expectError != nil {
assert.Equal(t, tc.expectError, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.expectRoute, c.Path())
for param, expectedValue := range tc.expectParam {
assert.Equal(t, expectedValue, c.pathParams.Get(param, ""))
}
checkUnusedParamValues(t, c, tc.expectParam)
})
}
}
// Issue #729
func TestRouterParamAlias(t *testing.T) {
api := []testRoute{
@ -3133,6 +3190,232 @@ func TestDefaultRouter_OptionsMethodHandler(t *testing.T) {
assert.Equal(t, "not empty", body)
}
func TestRouter_RouteWhenNotFoundRouteWithNodeSplitting(t *testing.T) {
e := New()
r := e.router
hf := func(c Context) error {
return c.String(http.StatusOK, c.RouteInfo().Name())
}
r.Add(Route{Method: http.MethodGet, Path: "/test*", Handler: hf, Name: "0"})
r.Add(Route{Method: RouteNotFound, Path: "/test*", Handler: hf, Name: "1"})
r.Add(Route{Method: RouteNotFound, Path: "/test", Handler: hf, Name: "2"})
// Tree before:
// 1 `/`
// 1.1 `*` (any) ID=1
// 1.2 `test` (static) ID=2
// 1.2.1 `*` (any) ID=0
// node with path `test` has routeNotFound handler from previous Add call. Now when we insert `/te/st*` into router tree
// This means that node `test` is split into `te` and `st` nodes and new node `/st*` is inserted.
// On that split `/test` routeNotFound handler must not be lost.
r.Add(Route{Method: http.MethodGet, Path: "/te/st*", Handler: hf, Name: "3"})
// Tree after:
// 1 `/`
// 1.1 `*` (any) ID=1
// 1.2 `te` (static)
// 1.2.1 `st` (static) ID=2
// 1.2.1.1 `*` (any) ID=0
// 1.2.2 `/st` (static)
// 1.2.2.1 `*` (any) ID=3
_, body := request(http.MethodPut, "/test", e)
assert.Equal(t, "2", body)
}
func TestRouter_RouteWhenNotFoundRouteAnyKind(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectID int
expectParam map[string]string
}{
{
name: "route not existent /xx to not found handler /*",
whenURL: "/xx",
expectRoute: "/*",
expectID: 4,
expectParam: map[string]string{"*": "xx"},
},
{
name: "route not existent /a/xx to not found handler /a/*",
whenURL: "/a/xx",
expectRoute: "/a/*",
expectID: 5,
expectParam: map[string]string{"*": "xx"},
},
{
name: "route not existent /a/c/dxxx to not found handler /a/c/d*",
whenURL: "/a/c/dxxx",
expectRoute: "/a/c/d*",
expectID: 6,
expectParam: map[string]string{"*": "xxx"},
},
{
name: "route /a/c/df to /a/c/df",
whenURL: "/a/c/df",
expectRoute: "/a/c/df",
expectID: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.contextPathParamAllocSize = 1
r := e.router
r.Add(Route{Method: http.MethodGet, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"})
r.Add(Route{Method: http.MethodGet, Path: "/a/c/df", Handler: handlerHelper("ID", 1), Name: "1"})
r.Add(Route{Method: http.MethodGet, Path: "/a/b*", Handler: handlerHelper("ID", 2), Name: "2"})
r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 3), Name: "3"})
r.Add(Route{Method: RouteNotFound, Path: "/a/c/d*", Handler: handlerHelper("ID", 6), Name: "6"})
r.Add(Route{Method: RouteNotFound, Path: "/a/*", Handler: handlerHelper("ID", 5), Name: "5"})
r.Add(Route{Method: RouteNotFound, Path: "/*", Handler: handlerHelper("ID", 4), Name: "4"})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
c := e.NewContext(req, nil).(*DefaultContext)
handler := r.Route(c)
handler(c)
testValue, _ := c.Get("ID").(int)
assert.Equal(t, tc.expectID, testValue)
assert.Equal(t, tc.expectRoute, c.Path())
for param, expectedValue := range tc.expectParam {
assert.Equal(t, expectedValue, c.PathParam(param))
}
checkUnusedParamValues(t, c, tc.expectParam)
})
}
}
func TestRouter_RouteWhenNotFoundRouteParamKind(t *testing.T) {
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectID int
expectParam map[string]string
}{
{
name: "route not existent /xx to not found handler /:file",
whenURL: "/xx",
expectRoute: "/:file",
expectID: 4,
expectParam: map[string]string{"file": "xx"},
},
{
name: "route not existent /a/xx to not found handler /a/:file",
whenURL: "/a/xx",
expectRoute: "/a/:file",
expectID: 5,
expectParam: map[string]string{"file": "xx"},
},
{
name: "route not existent /a/c/dxxx to not found handler /a/c/d:file",
whenURL: "/a/c/dxxx",
expectRoute: "/a/c/d:file",
expectID: 6,
expectParam: map[string]string{"file": "xxx"},
},
{
name: "route /a/c/df to /a/c/df",
whenURL: "/a/c/df",
expectRoute: "/a/c/df",
expectID: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.contextPathParamAllocSize = 1
r := e.router
r.Add(Route{Method: http.MethodGet, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"})
r.Add(Route{Method: http.MethodGet, Path: "/a/c/df", Handler: handlerHelper("ID", 1), Name: "1"})
r.Add(Route{Method: http.MethodGet, Path: "/a/b*", Handler: handlerHelper("ID", 2), Name: "2"})
r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 3), Name: "3"})
r.Add(Route{Method: RouteNotFound, Path: "/a/c/d:file", Handler: handlerHelper("ID", 6), Name: "6"})
r.Add(Route{Method: RouteNotFound, Path: "/a/:file", Handler: handlerHelper("ID", 5), Name: "5"})
r.Add(Route{Method: RouteNotFound, Path: "/:file", Handler: handlerHelper("ID", 4), Name: "4"})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
c := e.NewContext(req, nil).(*DefaultContext)
handler := r.Route(c)
handler(c)
testValue, _ := c.Get("ID").(int)
assert.Equal(t, tc.expectID, testValue)
assert.Equal(t, tc.expectRoute, c.Path())
for param, expectedValue := range tc.expectParam {
assert.Equal(t, expectedValue, c.PathParam(param))
}
checkUnusedParamValues(t, c, tc.expectParam)
})
}
}
func TestRouter_RouteWhenNotFoundRouteStaticKind(t *testing.T) {
// note: static not found handler is quite silly thing to have but we still support it
var testCases = []struct {
name string
whenURL string
expectRoute interface{}
expectID int
expectParam map[string]string
}{
{
name: "route not existent / to not found handler /",
whenURL: "/",
expectRoute: "/",
expectID: 3,
expectParam: map[string]string{},
},
{
name: "route /a to /a",
whenURL: "/a",
expectRoute: "/a",
expectID: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.contextPathParamAllocSize = 1
r := e.router
r.Add(Route{Method: http.MethodPut, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"})
r.Add(Route{Method: http.MethodGet, Path: "/a", Handler: handlerHelper("ID", 1), Name: "1"})
r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 2), Name: "2"})
r.Add(Route{Method: RouteNotFound, Path: "/", Handler: handlerHelper("ID", 3), Name: "3"})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
c := e.NewContext(req, nil).(*DefaultContext)
handler := r.Route(c)
handler(c)
testValue, _ := c.Get("ID").(int)
assert.Equal(t, tc.expectID, testValue)
assert.Equal(t, tc.expectRoute, c.Path())
for param, expectedValue := range tc.expectParam {
assert.Equal(t, expectedValue, c.PathParam(param))
}
checkUnusedParamValues(t, c, tc.expectParam)
})
}
}
type mySimpleRouter struct {
route Route
}
@ -3227,33 +3510,3 @@ func BenchmarkRouterGooglePlusAPIMisses(b *testing.B) {
func BenchmarkRouterParamsAndAnyAPI(b *testing.B) {
benchmarkRouterRoutes(b, paramAndAnyAPI, paramAndAnyAPIToFind)
}
func (n *node) printTree(pfx string, tail bool) {
p := prefix(tail, pfx, "└── ", "├── ")
fmt.Printf("%s%s, %p: type=%d, parent=%p, handler=%v, paramNames=%v\n", p, n.prefix, n, n.kind, n.parent, n.methods, n.paramsCount)
p = prefix(tail, pfx, " ", "│ ")
children := n.staticChildren
l := len(children)
if n.paramChild != nil {
n.paramChild.printTree(p, n.anyChild == nil && l == 0)
}
if n.anyChild != nil {
n.anyChild.printTree(p, l == 0)
}
for i := 0; i < l-1; i++ {
children[i].printTree(p, false)
}
if l > 0 {
children[l-1].printTree(p, true)
}
}
func prefix(tail bool, p, on, off string) string {
if tail {
return fmt.Sprintf("%s%s", p, on)
}
return fmt.Sprintf("%s%s", p, off)
}

View File

@ -696,8 +696,7 @@ func TestStartConfig_WithHidePort(t *testing.T) {
}
assert.NoError(t, <-errCh)
portMsg := fmt.Sprintf("http(s) server started on")
contains := strings.Contains(buf.String(), portMsg)
contains := strings.Contains(buf.String(), "http(s) server started on")
if tc.hidePort {
assert.False(t, contains)
} else {