diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index d1967212..af6a5a05 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -19,7 +19,7 @@ on: - '_fixture/**' - '.github/**' - 'codecov.yml' - workflow_dispatch: # to be able to run workflow manually + workflow_dispatch: jobs: test: @@ -29,72 +29,72 @@ jobs: # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases # except v5 starts from 1.17 until there is last four major releases after that - go: [1.17] + go: [1.17, 1.18] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - name: Checkout Code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.ref }} + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/lint/golint + run: | + go install golang.org/x/lint/golint@latest + go install honnef.co/go/tools/cmd/staticcheck@latest - name: Run Tests run: | golint -set_exit_status ./... + staticcheck ./... go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - - name: Upload coverage to Codecov - if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v2 + if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v1 with: + token: fail_ci_if_error: false - benchmark: needs: test strategy: matrix: os: [ubuntu-latest] - go: [1.17] + go: [1.18] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - name: Checkout Code (Previous) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: new + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/perf/cmd/benchstat + run: go install golang.org/x/perf/cmd/benchstat@latest - name: Run Benchmark (Previous) run: | cd previous go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - - name: Run Benchmark (New) run: | cd new go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - - name: Run Benchstat run: | benchstat previous/benchmark.txt new/benchmark.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 372ed13c..ba75d71f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,50 @@ # Changelog +## v4.7.2 - 2022-03-16 + +**Fixes** + +* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131) +* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136) +* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126) + +**Enhancements** + +* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134) + + +## v4.7.1 - 2022-03-13 + +**Fixes** + +* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123) + +**Enhancements** + +* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116) + + +## v4.7.0 - 2022-03-01 + +**Enhancements** + +* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060) +* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072) +* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027) +* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064) + +**Fixes** + +* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007) +* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102) + +**General** + +* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103) +* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078) +* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049) +* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README + ## v4.6.3 - 2022-01-10 **Fixes** diff --git a/Makefile b/Makefile index 10f9c8f5..8149aeba 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,8 @@ tag: check: lint vet race ## Check project init: - @go get -u golang.org/x/lint/golint + @go install golang.org/x/lint/golint@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest lint: ## Lint the files @golint -set_exit_status ${PKG_LIST} @@ -29,6 +30,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.16" -test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16 +goversion ?= "1.17" +test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/binder.go b/binder.go index 402b80bc..e022a7e8 100644 --- a/binder.go +++ b/binder.go @@ -1,6 +1,8 @@ package echo import ( + "encoding" + "encoding/json" "fmt" "net/http" "strconv" @@ -52,8 +54,11 @@ import ( * time * duration * BindUnmarshaler() interface + * TextUnmarshaler() interface + * JSONUnmarshaler() interface * UnixTime() - converts unix time (integer) to time.Time - * UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time + * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time + * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error` */ @@ -204,7 +209,7 @@ func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []st return b.customFunc(sourceParam, customFunc, false) } -// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist. +// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist. func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { return b.customFunc(sourceParam, customFunc, true) } @@ -241,7 +246,7 @@ func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { return b } -// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist +// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -270,7 +275,7 @@ func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { return b } -// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist +// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -302,7 +307,7 @@ func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) return b } -// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface. +// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -321,13 +326,85 @@ func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshal return b } +// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface +func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface +func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + +// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + // BindWithDelimiter binds parameter to destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, false) } -// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function. +// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, true) @@ -376,7 +453,7 @@ func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, false) } -// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist +// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, true) } @@ -386,7 +463,7 @@ func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, false) } -// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist +// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, true) } @@ -396,7 +473,7 @@ func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, false) } -// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist +// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, true) } @@ -406,7 +483,7 @@ func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, false) } -// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist +// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, true) } @@ -416,7 +493,7 @@ func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, false) } -// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist +// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, true) } @@ -544,7 +621,7 @@ func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist +// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -554,7 +631,7 @@ func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist +// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -564,7 +641,7 @@ func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist +// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -574,7 +651,7 @@ func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist +// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -584,7 +661,7 @@ func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist +// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -594,7 +671,7 @@ func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, false) } -// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist +// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, true) } @@ -604,7 +681,7 @@ func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, false) } -// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist +// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, true) } @@ -614,7 +691,7 @@ func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, false) } -// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist +// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, true) } @@ -624,7 +701,7 @@ func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist +// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -634,7 +711,7 @@ func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist +// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -644,7 +721,7 @@ func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, false) } -// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist +// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, true) } @@ -772,7 +849,7 @@ func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist +// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -782,7 +859,7 @@ func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist +// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -792,7 +869,7 @@ func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist +// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -802,7 +879,7 @@ func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist +// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -812,7 +889,7 @@ func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist +// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -822,7 +899,7 @@ func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, false) } -// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist +// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, true) } @@ -887,7 +964,7 @@ func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, false) } -// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist +// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, true) } @@ -897,7 +974,7 @@ func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, false) } -// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist +// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, true) } @@ -907,7 +984,7 @@ func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, false) } -// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist +// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, true) } @@ -992,7 +1069,7 @@ func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist +// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1002,7 +1079,7 @@ func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist +// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1012,7 +1089,7 @@ func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) * return b.time(sourceParam, dest, layout, false) } -// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist +// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { return b.time(sourceParam, dest, layout, true) } @@ -1043,7 +1120,7 @@ func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string return b.times(sourceParam, dest, layout, false) } -// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist +// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { return b.times(sourceParam, dest, layout, true) } @@ -1084,7 +1161,7 @@ func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBi return b.duration(sourceParam, dest, false) } -// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist +// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { return b.duration(sourceParam, dest, true) } @@ -1115,7 +1192,7 @@ func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *Valu return b.durationsValue(sourceParam, dest, false) } -// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist +// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { return b.durationsValue(sourceParam, dest, true) } @@ -1161,10 +1238,10 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, false) + return b.unixTime(sourceParam, dest, false, time.Second) } -// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time). Returns error when value does not exist. // // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 @@ -1172,10 +1249,31 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, false) + return b.unixTime(sourceParam, dest, true, time.Second) } -// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision). +// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision). +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Millisecond) +} + +// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding +// to the given Unix time in millisecond precision). Returns error when value does not exist. +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Millisecond) +} + +// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 @@ -1185,10 +1283,10 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, true) + return b.unixTime(sourceParam, dest, false, time.Nanosecond) } -// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding // to the given Unix time value in nano second precision). Returns error when value does not exist. // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 @@ -1199,10 +1297,10 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, true) + return b.unixTime(sourceParam, dest, true, time.Nanosecond) } -func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder { +func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -1221,10 +1319,13 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi return b } - if isNano { - *dest = time.Unix(0, n) - } else { + switch precision { + case time.Second: *dest = time.Unix(n, 0) + case time.Millisecond: + *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows + case time.Nanosecond: + *dest = time.Unix(0, n) } return b } diff --git a/binder_test.go b/binder_test.go index f57da32d..3c17057c 100644 --- a/binder_test.go +++ b/binder_test.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "io" + "math/big" "net/http" "net/http/httptest" "strconv" @@ -55,7 +56,7 @@ func TestBindingError_Error(t *testing.T) { func TestBindingError_ErrorJSON(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) - resp, err := json.Marshal(err) + resp, _ := json.Marshal(err) assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } @@ -2188,6 +2189,188 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { } } +func TestValueBinder_JSONUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustJSONUnmarshaler("param", &dest).BindError() + } else { + err = b.JSONUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TextUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustTextUnmarshaler("param", &dest).BindError() + } else { + err = b.TextUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { name string @@ -2530,6 +2713,97 @@ func TestValueBinder_UnixTime(t *testing.T) { } } +func TestValueBinder_UnixTimeMilli(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value, unix time in milliseconds", + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeMilli("param", &dest).BindError() + } else { + err = b.UnixTimeMilli("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 diff --git a/context.go b/context.go index a397fba7..5f4a2010 100644 --- a/context.go +++ b/context.go @@ -55,6 +55,16 @@ type Context interface { // PathParam returns path parameter by name. PathParam(name string) string + // PathParamDefault returns the path parameter or default value for the provided name. + // + // Notes for DefaultRouter implementation: + // Path parameter could be empty for cases like that: + // * route `/release-:version/bin` and request URL is `/release-/bin` + // * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` + // but not when path parameter is last part of route path + // * route `/download/file.:ext` will not match request `/download/file.` + PathParamDefault(name string, defaultValue string) string + // PathParams returns path parameter values. PathParams() PathParams @@ -176,6 +186,9 @@ type Context interface { Redirect(code int, url string) error // Echo returns the `Echo` instance. + // + // WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them + // anywhere in your code after Echo server has started. Echo() *Echo } @@ -379,7 +392,10 @@ func (c *DefaultContext) PathParam(name string) string { return c.pathParams.Get(name, "") } -// PathParamDefault does not exist as expecting empty path param makes no sense +// PathParamDefault returns the path parameter or default value for the provided name. +func (c *DefaultContext) PathParamDefault(name, defaultValue string) string { + return c.pathParams.Get(name, defaultValue) +} // PathParams returns path parameter values. func (c *DefaultContext) PathParams() PathParams { @@ -406,7 +422,8 @@ func (c *DefaultContext) QueryParam(name string) string { } // QueryParamDefault returns the query param or default value for the provided name. -// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string +// Note: QueryParamDefault does not distinguish if query had no value by that name or value was empty string +// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamDefault("search", "1")` func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string { value := c.QueryParam(name) if value == "" { diff --git a/context_test.go b/context_test.go index dca680f9..0df10539 100644 --- a/context_test.go +++ b/context_test.go @@ -448,22 +448,142 @@ func TestContextCookie(t *testing.T) { assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } -func TestContextPathParam(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil).(*DefaultContext) - - params := &PathParams{ - {Name: "uid", Value: "101"}, - {Name: "fid", Value: "501"}, +func TestContext_PathParams(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + expect PathParams + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + expect: PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + }, + { + name: "params is empty", + given: &PathParams{}, + expect: PathParams{}, + }, } - // ParamNames - c.pathParams = params - assert.EqualValues(t, *params, c.PathParams()) - // Param - assert.Equal(t, "501", c.PathParam("fid")) - assert.Equal(t, "", c.PathParam("undefined")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) + + assert.EqualValues(t, tc.expect, c.PathParams()) + }) + } +} + +func TestContext_PathParam(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + whenParamName string + expect string + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "multiple same param values exists - return first", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "uid", Value: "202"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "param does not exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) + + assert.EqualValues(t, tc.expect, c.PathParam(tc.whenParamName)) + }) + } +} + +func TestContext_PathParamDefault(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "101", + }, + { + name: "param exists and is empty", + given: &PathParams{ + {Name: "uid", Value: ""}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "", // <-- this is different from QueryParamDefault behaviour + }, + { + name: "param does not exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + whenDefaultValue: "999", + expect: "999", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) + + assert.EqualValues(t, tc.expect, c.PathParamDefault(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextGetAndSetParam(t *testing.T) { @@ -568,27 +688,129 @@ func TestContextFormValue(t *testing.T) { assert.Error(t, err) } -func TestContextQueryParam(t *testing.T) { - q := make(url.Values) - q.Set("name", "Jon Snow") - q.Set("email", "jon@labstack.com") - req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) - e := New() - c := e.NewContext(req, nil) +func TestContext_QueryParams(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect url.Values + }{ + { + name: "multiple values in url", + givenURL: "/?test=1&test=2&email=jon%40labstack.com", + expect: url.Values{ + "test": []string{"1", "2"}, + "email": []string{"jon@labstack.com"}, + }, + }, + { + name: "single value in url", + givenURL: "/?nope=1", + expect: url.Values{ + "nope": []string{"1"}, + }, + }, + { + name: "no query params in url", + givenURL: "/?", + expect: url.Values{}, + }, + } - // QueryParam - assert.Equal(t, "Jon Snow", c.QueryParam("name")) - assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) - // QueryParamDefault - assert.Equal(t, "Jon Snow", c.QueryParamDefault("name", "nope")) - assert.Equal(t, "default", c.QueryParamDefault("missing", "default")) + assert.Equal(t, tc.expect, c.QueryParams()) + }) + } +} - // QueryParams - assert.Equal(t, url.Values{ - "name": []string{"Jon Snow"}, - "email": []string{"jon@labstack.com"}, - }, c.QueryParams()) +func TestContext_QueryParam(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + expect: "1", + }, + { + name: "multiple values exists in url", + givenURL: "/?test=9&test=8", + whenParamName: "test", + expect: "9", // <-- first value in returned + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + expect: "", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) + }) + } +} + +func TestContext_QueryParamDefault(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "1", + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParamDefault(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextFormFile(t *testing.T) { diff --git a/echo.go b/echo.go index 29fa2281..c43d26e2 100644 --- a/echo.go +++ b/echo.go @@ -49,12 +49,15 @@ import ( "os" "os/signal" "path/filepath" + "runtime" "strings" "sync" ) // Echo is the top-level framework instance. -// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. +// +// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. This is very likely +// to happen when you access Echo instances through Context.Echo() method. type Echo struct { // premiddleware are middlewares that are run for every request before routing is done premiddleware []MiddlewareFunc @@ -66,8 +69,8 @@ type Echo struct { routerCreator func(e *Echo) Router contextPool sync.Pool - // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info for context - // creation time so we can allocate path parameter values slice. + // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info at context + // creation moment so we can allocate path parameter values slice with correct size. contextPathParamAllocSize int // NewContextFunc allows using custom context implementations, instead of default *echo.context @@ -150,6 +153,8 @@ const ( PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" + // RouteNotFound is special method type for routes handling "route not found" (404) cases + RouteNotFound = "echo_route_not_found" ) // Headers @@ -181,12 +186,14 @@ const ( HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXRealIP = "X-Real-IP" - HeaderXRequestID = "X-Request-ID" - HeaderXCorrelationID = "X-Correlation-ID" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" @@ -403,6 +410,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo return e.Add(http.MethodTrace, path, h, m...) } +// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) +// for current request URL. +// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with +// wildcard/match-any character (`/*`, `/download/*` etc). +// +// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return e.Add(RouteNotFound, path, h, m...) +} + // Any registers a new route for all supported HTTP methods and path with matching handler // in the router with optional route-level middleware. Panics on error. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { @@ -707,20 +724,26 @@ func newDefaultFS() *defaultFS { dir, _ := os.Getwd() return &defaultFS{ prefix: dir, - fs: os.DirFS(dir), + fs: nil, } } func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) + } return fs.fs.Open(name) } func subFS(currentFs fs.FS, root string) (fs.FS, error) { root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to - // allow cases when root is given as `../somepath` which is not valid for fs.FS - root = filepath.Join(dFS.prefix, root) + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if isRelativePath(root) { + root = filepath.Join(dFS.prefix, root) + } return &defaultFS{ prefix: root, fs: os.DirFS(root), @@ -729,6 +752,21 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { return fs.Sub(currentFs, root) } +func isRelativePath(path string) bool { + if path == "" { + return true + } + if path[0] == '/' { + return false + } + if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { + // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names + // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats + return false + } + return true +} + // MustSubFS creates sub FS from current filesystem or panic on failure. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // diff --git a/echo_test.go b/echo_test.go index ba21014e..b75bd253 100644 --- a/echo_test.go +++ b/echo_test.go @@ -336,16 +336,54 @@ func TestEchoStaticRedirectIndex(t *testing.T) { } func TestEchoFile(t *testing.T) { - e := New() - ri := e.File("/walle", "_fixture/images/walle.png") - assert.Equal(t, http.MethodGet, ri.Method()) - assert.Equal(t, "/walle", ri.Path()) - assert.Equal(t, "GET:/walle", ri.Name()) - assert.Nil(t, ri.Params()) + var testCases = []struct { + name string + givenPath string + givenFile string + whenPath string + expectCode int + expectStartsWith string + }{ + { + name: "ok", + givenPath: "/walle", + givenFile: "_fixture/images/walle.png", + whenPath: "/walle", + expectCode: http.StatusOK, + expectStartsWith: string([]byte{0x89, 0x50, 0x4e}), + }, + { + name: "ok with relative path", + givenPath: "/", + givenFile: "./go.mod", + whenPath: "/", + expectCode: http.StatusOK, + expectStartsWith: "module github.com/labstack/echo/v", + }, + { + name: "nok file does not exist", + givenPath: "/", + givenFile: "./this-file-does-not-exist", + whenPath: "/", + expectCode: http.StatusNotFound, + expectStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } - c, b := request(http.MethodGet, "/walle", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() // we are using echo.defaultFS instance + e.File(tc.givenPath, tc.givenFile) + + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) + + if len(b) > len(tc.expectStartsWith) { + b = b[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, b) + }) + } } func TestEchoMiddleware(t *testing.T) { @@ -880,6 +918,70 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "023", buf.String()) } +func TestEcho_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /a/c/xx", + whenURL: "/a/c/xx", + expectRoute: "GET /a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /a/:file", + whenURL: "/a/echo.exe", + expectRoute: "GET /a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /*", + whenURL: "/b/echo.exe", + expectRoute: "GET /*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "GET /a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + e.GET("/", okHandler) + e.GET("/a/c/df", okHandler) + e.GET("/a/b*", okHandler) + e.PUT("/*", okHandler) + + e.RouteNotFound("/a/c/xx", notFoundHandler) // static + e.RouteNotFound("/a/:file", notFoundHandler) // param + e.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestEchoNotFound(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/files", nil) diff --git a/group.go b/group.go index 4f04a73a..b9df5af9 100644 --- a/group.go +++ b/group.go @@ -157,6 +157,13 @@ func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo return g.Add(http.MethodGet, path, handler, middleware...) } +// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. +// +// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return g.Add(RouteNotFound, path, h, m...) +} + // Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := g.AddRoute(Route{ diff --git a/group_test.go b/group_test.go index 3914c0bd..bd215726 100644 --- a/group_test.go +++ b/group_test.go @@ -303,6 +303,71 @@ func TestGroup_TRACE(t *testing.T) { assert.Equal(t, `OK`, body) } +func TestGroup_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /group/a/c/xx", + whenURL: "/group/a/c/xx", + expectRoute: "GET /group/a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /group/a/:file", + whenURL: "/group/a/echo.exe", + expectRoute: "GET /group/a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /group/*", + whenURL: "/group/b/echo.exe", + expectRoute: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /group/a/c/df to /group/a/c/df", + whenURL: "/group/a/c/df", + expectRoute: "GET /group/a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/group") + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + g.GET("/", okHandler) + g.GET("/a/c/df", okHandler) + g.GET("/a/b*", okHandler) + g.PUT("/*", okHandler) + + g.RouteNotFound("/a/c/xx", notFoundHandler) // static + g.RouteNotFound("/a/:file", notFoundHandler) // param + g.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestGroup_Any(t *testing.T) { e := New() diff --git a/ip.go b/ip.go index 39cb421f..46d464cf 100644 --- a/ip.go +++ b/ip.go @@ -6,6 +6,130 @@ import ( "strings" ) +/** +By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 ) +Source: https://echo.labstack.com/guide/ip-address/ + +IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more. +Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that. + +However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application. +In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally. +Otherwise, you might give someone a chance of deceiving you. **A security risk!** + +To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure. +In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. +This guides show you why and how. + +> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. + +Let's start from two questions to know the right direction: + +1. Do you put any HTTP (L7) proxy in front of the application? + - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway). +2. If yes, what HTTP header do your proxies use to pass client IP to the application? + +## Case 1. With no proxy + +If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer. +Any HTTP header is untrustable because the clients have full control what headers to be set. + +In this case, use `echo.ExtractIPDirect()`. + +```go +e.IPExtractor = echo.ExtractIPDirect() +``` + +## Case 2. With proxies using `X-Forwarded-For` header + +[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header +to relay clients' IP addresses. +At each hop on the proxies, they append the request IP address at the end of the header. + +Following example diagram illustrates this behavior. + +```text +┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │ +│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │ +└──────────┘ └──────────┘ └──────────┘ └──────────┘ + +Case 1. +XFF: "" "a" "a, b" + ~~~~~~ +Case 2. +XFF: "x" "x, a" "x, a, b" + ~~~~~~~~~ + ↑ What your app will see +``` + +In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". +In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. + +In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader() +``` + +By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +E.g.: + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader( + TrustLinkLocal(false), + TrustIPRanges(lbIPRange), +) +``` + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +## Case 3. With proxies using `X-Real-IP` header + +`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF. + +If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromRealIPHeader() +``` + +Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**. +> Otherwise there is a chance of fraud, as it is what clients can control. + +## About default behavior + +In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer. + +As you might already notice, after reading this article, this is not good. +Sole reason this is default is just backward compatibility. + +## Private IP ranges + +See: https://en.wikipedia.org/wiki/Private_network + +Private IPv4 address ranges (RFC 1918): +* 10.0.0.0 – 10.255.255.255 (24-bit block) +* 172.16.0.0 – 172.31.255.255 (20-bit block) +* 192.168.0.0 – 192.168.255.255 (16-bit block) + +Private IPv6 address ranges: +* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + +*/ + type ipChecker struct { trustLoopback bool trustLinkLocal bool @@ -52,6 +176,7 @@ func newIPChecker(configs []TrustOption) *ipChecker { return checker } +// Go1.16+ added `ip.IsPrivate()` but until that use this implementation func isPrivateIPRange(ip net.IP) bool { if ip4 := ip.To4(); ip4 != nil { return ip4[0] == 10 || @@ -87,10 +212,12 @@ type IPExtractor func(*http.Request) string // ExtractIPDirect extracts IP address using actual IP address. // Use this if your server faces to internet directory (i.e.: uses no proxy). func ExtractIPDirect() IPExtractor { - return func(req *http.Request) string { - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra - } + return extractIP +} + +func extractIP(req *http.Request) string { + ra, _, _ := net.SplitHostPort(req.RemoteAddr) + return ra } // ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. @@ -98,14 +225,13 @@ func ExtractIPDirect() IPExtractor { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { - if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) { + if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } } - return directIP + return extractIP(req) } } @@ -115,7 +241,7 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) + directIP := extractIP(req) xffs := req.Header[HeaderXForwardedFor] if len(xffs) == 0 { return directIP diff --git a/ip_test.go b/ip_test.go index 5acc1179..755900d3 100644 --- a/ip_test.go +++ b/ip_test.go @@ -1,235 +1,606 @@ package echo import ( + "github.com/stretchr/testify/assert" "net" "net/http" - "strings" "testing" - - testify "github.com/stretchr/testify/assert" ) -const ( - // For RemoteAddr - ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8 - sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080" - ipForRemoteAddrExternal = "203.0.113.1" - sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080" - // For x-real-ip - ipForRealIP = "203.0.113.10" - // For XFF - ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16 - ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16 - ipForXFF3External = "2001:db8::103" - ipForXFF4Private = "fc00::104" // From fc00::/7 - ipForXFF5External = "198.51.100.105" - ipForXFF6External = "192.0.2.106" - ipForXFFBroken = "this.is.broken.lol" - // keys for test cases - ipTestReqKeyNoHeader = "no header" - ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external" - ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal" - ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external" - ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal" - ipTestReqKeyXFFExternal = "xff; remote addr external" - ipTestReqKeyXFFInternal = "xff; remote addr internal" - ipTestReqKeyBrokenXFF = "broken xff" -) - -var ( - sampleXFF = strings.Join([]string{ - ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal, - }, ", ") - - requests = map[string]*http.Request{ - ipTestReqKeyNoHeader: &http.Request{ - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyRealIPExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - }, - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyRealIPInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, - ipTestReqKeyRealIPAndXFFExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyRealIPAndXFFInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, - ipTestReqKeyXFFExternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyXFFInternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, - ipTestReqKeyBrokenXFF: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, +func mustParseCIDR(s string) *net.IPNet { + _, IPNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) } -) + return IPNet +} -func TestExtractIP(t *testing.T) { - _, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0") - _, ipv6AllRange, _ := net.ParseCIDR("::/0") - _, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48") - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24") - - tests := map[string]*struct { - extractor IPExtractor - expectedIPs map[string]string +func TestIPChecker_TrustOption(t *testing.T) { + var testCases = []struct { + name string + givenOptions []TrustOption + whenIP string + expect bool }{ - "ExtractIPDirect": { - ExtractIPDirect(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustLoopback(false), + TrustLinkLocal(false), + TrustPrivateNet(false), + // this is private IPv6 ip + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + TrustIPRange(mustParseCIDR("2001:db8::103/48")), }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, }, - "ExtractIPFromRealIPHeader(default)": { - ExtractIPFromRealIPHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust only direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(default)": { - ExtractIPFromXFFHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust only direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForXFF3External, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust everything)": { - // This is similar to legacy behavior, but ignores x-real-ip header. - ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External, - ipTestReqKeyXFFExternal: ipForXFF6External, - ipTestReqKeyXFFInternal: ipForXFF6External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust ipForXFF3External)": { - // This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't - ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF5External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustIPRange(mustParseCIDR("2001:db8::103/48")), }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, }, } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - assert := testify.New(t) - for key, req := range requests { - actual := test.extractor(req) - expected := test.expectedIPs[key] - assert.Equal(expected, actual, "Request: %s", key) - } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker(tc.givenOptions) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustIPRange(t *testing.T) { + var testCases = []struct { + name string + givenRange string + whenIP string + expect bool + }{ + { + name: "ip is within trust range, IPV6 network range", + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff", + expect: false, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.9.0", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.7.255", + expect: false, + }, + { + name: "public ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "internal ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "127.0.10.1", + expect: true, + }, + { + name: "public ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "2a00:1450:4026:805::200e", + expect: true, + }, + { + name: "internal ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "0:0:0:0:0:0:0:1", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cidr := mustParseCIDR(tc.givenRange) + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustIPRange(cidr), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustPrivateNet(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "do not trust public IPv4 address", + whenIP: "8.8.8.8", + expect: false, + }, + { + name: "do not trust public IPv6 address", + whenIP: "2a00:1450:4026:805::200e", + expect: false, + }, + + { // Class A: 10.0.0.0 — 10.255.255.255 + name: "do not trust IPv4 just outside of class A (lower bounds)", + whenIP: "9.255.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class A (upper bounds)", + whenIP: "11.0.0.0", + expect: false, + }, + { + name: "trust IPv4 of class A (lower bounds)", + whenIP: "10.0.0.0", + expect: true, + }, + { + name: "trust IPv4 of class A (upper bounds)", + whenIP: "10.255.255.255", + expect: true, + }, + + { // Class B: 172.16.0.0 — 172.31.255.255 + name: "do not trust IPv4 just outside of class B (lower bounds)", + whenIP: "172.15.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class B (upper bounds)", + whenIP: "172.32.0.0", + expect: false, + }, + { + name: "trust IPv4 of class B (lower bounds)", + whenIP: "172.16.0.0", + expect: true, + }, + { + name: "trust IPv4 of class B (upper bounds)", + whenIP: "172.31.255.255", + expect: true, + }, + + { // Class C: 192.168.0.0 — 192.168.255.255 + name: "do not trust IPv4 just outside of class C (lower bounds)", + whenIP: "192.167.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class C (upper bounds)", + whenIP: "192.169.0.0", + expect: false, + }, + { + name: "trust IPv4 of class C (lower bounds)", + whenIP: "192.168.0.0", + expect: true, + }, + { + name: "trust IPv4 of class C (upper bounds)", + whenIP: "192.168.255.255", + expect: true, + }, + + { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + // splits the address block in two equally sized halves, fc00::/8 and fd00::/8. + // https://en.wikipedia.org/wiki/Unique_local_address + name: "trust IPv6 private address", + whenIP: "fdfc:3514:2cb3:4bd5::", + expect: true, + }, + { + name: "do not trust IPv6 just out of /fd (upper bounds)", + whenIP: "/fe00:0000:0000:0000:0000", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + + TrustPrivateNet(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLinkLocal(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust link local IPv4 address (lower bounds)", + whenIP: "169.254.0.0", + expect: true, + }, + { + name: "trust link local IPv4 address (upper bounds)", + whenIP: "169.254.255.255", + expect: true, + }, + { + name: "do not trust link local IPv4 address (outside of lower bounds)", + whenIP: "169.253.255.255", + expect: false, + }, + { + name: "do not trust link local IPv4 address (outside of upper bounds)", + whenIP: "169.255.0.0", + expect: false, + }, + { + name: "trust link local IPv6 address ", + whenIP: "fe80::1", + expect: true, + }, + { + name: "do not trust link local IPv6 address ", + whenIP: "fec0::1", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLinkLocal(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLoopback(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust IPv4 as localhost", + whenIP: "127.0.0.1", + expect: true, + }, + { + name: "trust IPv6 as localhost", + whenIP: "::1", + expect: true, + }, + { + name: "do not trust public ip as localhost", + whenIP: "8.8.8.8", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLoopback(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestExtractIPDirect(t *testing.T) { + var testCases = []struct { + name string + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"127.0.0.1"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPDirect()(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromRealIPHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromXFFHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request has INVALID external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.3", + }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed) + // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs) + // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office) + // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"}, + }, + RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP + }, + expectIP: "203.0.100.100", // this is first trusted IP in XFF chain + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) }) } } diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 82e2fbf7..3071eedb 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/base64" "errors" - "fmt" + "net/http" "strconv" "strings" @@ -72,9 +72,11 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { continue } + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) if errDecode != nil { - lastError = fmt.Errorf("invalid basic auth value: %w", errDecode) + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) continue } idx := bytes.IndexByte(b, ':') diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 9580dff0..3d69ae84 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -56,7 +56,7 @@ func TestBasicAuth(t *testing.T) { name: "nok, not base64 Authorization header", givenConfig: defaultConfig, whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, - expectErr: "invalid basic auth value: illegal base64 data at input byte 3", + expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3", }, { name: "nok, missing Authorization header", diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index f367f938..6e7778ea 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -43,6 +43,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) @@ -55,6 +56,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() diff --git a/middleware/csrf.go b/middleware/csrf.go index acab8790..895a9c63 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -144,7 +144,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastTokenErr error outer: for _, extractor := range extractors { - clientTokens, err := extractor(c) + clientTokens, _, err := extractor(c) if err != nil { lastExtractorErr = err continue diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index c35ed6fa..5651a493 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -79,9 +79,6 @@ func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { func TestDecompressWithConfig_DefaultConfig(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing @@ -91,10 +88,10 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) { // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) err := h(c) assert.NoError(t, err) diff --git a/middleware/extractor.go b/middleware/extractor.go index 94134f7e..4ad676f6 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -13,6 +13,24 @@ const ( extractorLimit = 20 ) +// ExtractorSource is type to indicate source for extracted value +type ExtractorSource string + +const ( + // ExtractorSourceHeader means value was extracted from request header + ExtractorSourceHeader ExtractorSource = "header" + // ExtractorSourceQuery means value was extracted from request query parameters + ExtractorSourceQuery ExtractorSource = "query" + // ExtractorSourcePathParam means value was extracted from route path parameters + ExtractorSourcePathParam ExtractorSource = "param" + // ExtractorSourceCookie means value was extracted from request cookies + ExtractorSourceCookie ExtractorSource = "cookie" + // ExtractorSourceForm means value was extracted from request form values + ExtractorSourceForm ExtractorSource = "form" + // ExtractorSourceCustom means value was extracted by custom extractor + ExtractorSourceCustom ExtractorSource = "custom" +) + // ValueExtractorError is error type when middleware extractor is unable to extract value from lookups type ValueExtractorError struct { message string @@ -31,7 +49,7 @@ var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing valu var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. -type ValuesExtractor func(c echo.Context) ([]string, error) +type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error) func createExtractors(lookups string) ([]ValuesExtractor, error) { if lookups == "" { @@ -75,10 +93,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { prefixLen := len(valuePrefix) // standard library parses http.Request header keys in canonical form but we may provide something else so fix this header = textproto.CanonicalMIMEHeaderKey(header) - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { values := c.Request().Header.Values(header) if len(values) == 0 { - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } result := make([]string, 0) @@ -100,30 +118,30 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { if len(result) == 0 { if prefixLen > 0 { - return nil, errHeaderExtractorValueInvalid + return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid } - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } - return result, nil + return result, ExtractorSourceHeader, nil } } // valuesFromQuery returns a function that extracts values from the query string. func valuesFromQuery(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { result := c.QueryParams()[param] if len(result) == 0 { - return nil, errQueryExtractorValueMissing + return nil, ExtractorSourceQuery, errQueryExtractorValueMissing } else if len(result) > extractorLimit-1 { result = result[:extractorLimit] } - return result, nil + return result, ExtractorSourceQuery, nil } } // valuesFromParam returns a function that extracts values from the url param string. func valuesFromParam(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { result := make([]string, 0) for i, p := range c.PathParams() { if param == p.Name { @@ -134,18 +152,18 @@ func valuesFromParam(param string) ValuesExtractor { } } if len(result) == 0 { - return nil, errParamExtractorValueMissing + return nil, ExtractorSourcePathParam, errParamExtractorValueMissing } - return result, nil + return result, ExtractorSourcePathParam, nil } } // valuesFromCookie returns a function that extracts values from the named cookie. func valuesFromCookie(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { cookies := c.Cookies() if len(cookies) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } result := make([]string, 0) @@ -158,26 +176,26 @@ func valuesFromCookie(name string) ValuesExtractor { } } if len(result) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } - return result, nil + return result, ExtractorSourceCookie, nil } } // valuesFromForm returns a function that extracts values from the form field. func valuesFromForm(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { - if parseErr := c.Request().ParseForm(); parseErr != nil { - return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr) + return func(c echo.Context) ([]string, ExtractorSource, error) { + if c.Request().Form == nil { + _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does } values := c.Request().Form[name] if len(values) == 0 { - return nil, errFormExtractorValueMissing + return nil, ExtractorSourceForm, errFormExtractorValueMissing } if len(values) > extractorLimit-1 { values = values[:extractorLimit] } result := append([]string{}, values...) - return result, nil + return result, ExtractorSourceForm, nil } } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 439c4d8f..afa776ec 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -1,9 +1,11 @@ package middleware import ( + "bytes" "fmt" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" + "mime/multipart" "net/http" "net/http/httptest" "net/url" @@ -18,6 +20,7 @@ func TestCreateExtractors(t *testing.T) { givenPathParams echo.PathParams whenLoopups string expectValues []string + expectSource ExtractorSource expectCreateError string expectError string }{ @@ -30,6 +33,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "header:Authorization:Bearer ", expectValues: []string{"token"}, + expectSource: ExtractorSourceHeader, }, { name: "ok, form", @@ -43,6 +47,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "form:name", expectValues: []string{"Jon Snow"}, + expectSource: ExtractorSourceForm, }, { name: "ok, cookie", @@ -53,6 +58,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "cookie:_csrf", expectValues: []string{"token"}, + expectSource: ExtractorSourceCookie, }, { name: "ok, param", @@ -61,6 +67,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "param:id", expectValues: []string{"123"}, + expectSource: ExtractorSourcePathParam, }, { name: "ok, query", @@ -70,6 +77,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "query:id", expectValues: []string{"999"}, + expectSource: ExtractorSourceQuery, }, { name: "nok, invalid lookup", @@ -100,8 +108,9 @@ func TestCreateExtractors(t *testing.T) { assert.NoError(t, err) for _, e := range extractors { - values, eErr := e(c) + values, source, eErr := e(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, tc.expectSource, source) if tc.expectError != "" { assert.EqualError(t, eErr, tc.expectError) return @@ -226,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) { extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceHeader, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -287,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) { extractor := valuesFromQuery(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceQuery, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -366,8 +377,9 @@ func TestValuesFromParam(t *testing.T) { extractor := valuesFromParam(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourcePathParam, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -446,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) { extractor := valuesFromCookie(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceCookie, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -483,6 +496,25 @@ func TestValuesFromForm(t *testing.T) { return req } + exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { + var b bytes.Buffer + w := multipart.NewWriter(&b) + w.WriteField("name", "Jon Snow") + w.WriteField("emails[]", "jon@labstack.com") + if mod != nil { + mod(w) + } + + fw, _ := w.CreateFormFile("upload", "my.file") + fw.Write([]byte(`
hi
`)) + w.Close() + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) + req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) + + return req + } + var testCases = []struct { name string givenRequest *http.Request @@ -504,6 +536,14 @@ func TestValuesFromForm(t *testing.T) { whenName: "emails[]", expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, + { + name: "ok, POST multipart/form, multiple value", + givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { + w.WriteField("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, { name: "ok, GET form, single value", givenRequest: exampleGetFormRequest(nil), @@ -524,16 +564,6 @@ func TestValuesFromForm(t *testing.T) { whenName: "nope", expectError: errFormExtractorValueMissing.Error(), }, - { - name: "nok, POST form, form parsing error", - givenRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodPost, "/", nil) - req.Body = nil - return req - }(), - whenName: "name", - expectError: "valuesFromForm parse form failed: missing form body", - }, { name: "ok, cut values over extractorLimit", givenRequest: examplePostFormRequest(func(v *url.Values) { @@ -559,8 +589,9 @@ func TestValuesFromForm(t *testing.T) { extractor := valuesFromForm(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceForm, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/middleware/jwt.go b/middleware/jwt.go index 40b45e77..cf30237b 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -64,7 +64,7 @@ type JWTConfig struct { // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token // parsing fails or parsed token is invalid. // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(c echo.Context, auth string) (interface{}, error) + ParseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) } // JWTSuccessHandler defines a function which is executed for a valid token. @@ -101,7 +101,7 @@ var DefaultJWTConfig = JWTConfig{ // For missing token, it returns "400 - Bad Request" error. // // See: https://jwt.io/introduction -func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc { +func JWT(parseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)) echo.MiddlewareFunc { c := DefaultJWTConfig c.ParseTokenFunc = parseTokenFunc return JWTWithConfig(c) @@ -152,13 +152,13 @@ func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastExtractorErr error var lastTokenErr error for _, extractor := range extractors { - auths, extrErr := extractor(c) + auths, source, extrErr := extractor(c) if extrErr != nil { lastExtractorErr = extrErr continue } for _, auth := range auths { - token, err := config.ParseTokenFunc(c, auth) + token, err := config.ParseTokenFunc(c, auth, source) if err != nil { lastTokenErr = err continue diff --git a/middleware/jwt_external_test.go b/middleware/jwt_external_test.go index 1b92f188..f854fcf4 100644 --- a/middleware/jwt_external_test.go +++ b/middleware/jwt_external_test.go @@ -21,7 +21,7 @@ import ( // This is one of the options to provide a token validation key. // The order of precedence is a user-defined SigningKeys and SigningKey. // Required if signingKey is not provided -func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) { +func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) { // keyFunc defines a user-defined function that supplies the public key for a token validation. // The function shall take care of verifying the signing algorithm and selecting the proper key. // A user-defined KeyFunc can be useful if tokens are issued by an external party. @@ -41,7 +41,7 @@ func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]in return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) } - return func(c echo.Context, auth string) (interface{}, error) { + return func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) { token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here if err != nil { return nil, err diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 5e5b9912..00472373 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/assert" ) -func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) { +func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) { // This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != signingMethod { @@ -23,7 +23,7 @@ func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface return signingKey, nil } - return func(c echo.Context, auth string) (interface{}, error) { + return func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) { token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) if err != nil { return nil, err @@ -405,7 +405,6 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) }, @@ -424,7 +423,7 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { config := tc.given parseTokenCalled := false - config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { + config.ParseTokenFunc = func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) { parseTokenCalled = true return nil, errors.New("parsing failed") } @@ -453,8 +452,10 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { // with current JWT middleware signingKey := []byte("secret") + var fromSource ExtractorSource config := JWTConfig{ - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) { + fromSource = source keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != "HS256" { return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) @@ -481,6 +482,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { res := httptest.NewRecorder() e.ServeHTTP(res, req) + assert.Equal(t, fromSource, ExtractorSourceHeader) assert.Equal(t, http.StatusTeapot, res.Code) } @@ -494,7 +496,7 @@ func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { }) mw, err := JWTConfig{ - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) { return auth, nil }, SuccessHandler: func(c echo.Context) { @@ -616,7 +618,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { mw, err := JWTConfig{ ContinueOnIgnoredError: tc.givenContinueOnIgnoredError, TokenLookup: tc.givenTokenLookup, - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) { return tc.whenParseReturn, tc.whenParseError }, ErrorHandler: tc.givenErrorHandler, @@ -648,8 +650,8 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) { e.Use(JWTWithConfig(JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), TokenLookupFuncs: []ValuesExtractor{ - func(c echo.Context) ([]string, error) { - return []string{c.Request().Header.Get("X-API-Key")}, nil + func(c echo.Context) ([]string, ExtractorSource, error) { + return []string{c.Request().Header.Get("X-API-Key")}, ExtractorSourceCustom, nil }, }, })) diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 77a001ea..d45142dd 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -50,7 +50,7 @@ type KeyAuthConfig struct { } // KeyAuthValidator defines a function to validate KeyAuth credentials. -type KeyAuthValidator func(c echo.Context, key string) (bool, error) +type KeyAuthValidator func(c echo.Context, key string, source ExtractorSource) (bool, error) // KeyAuthErrorHandler defines a function which is executed for an invalid key. type KeyAuthErrorHandler func(c echo.Context, err error) error @@ -116,13 +116,13 @@ func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { - keys, extrErr := extractor(c) + keys, source, extrErr := extractor(c) if extrErr != nil { lastExtractorErr = extrErr continue } for _, key := range keys { - valid, err := config.Validator(c, key) + valid, err := config.Validator(c, key, source) if err != nil { lastValidatorErr = err continue diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 1b64865f..fa182e6c 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" ) -func testKeyValidator(c echo.Context, key string) (bool, error) { +func testKeyValidator(c echo.Context, key string, source ExtractorSource) (bool, error) { switch key { case "valid-key": return true, nil @@ -218,6 +218,25 @@ func TestKeyAuthWithConfig(t *testing.T) { expectHandlerCalled: false, expectError: "code=401, message=Unauthorized, internal=some user defined error", }, + { + name: "ok, custom validator checks source", + givenRequest: func(req *http.Request) { + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + conf.Validator = func(c echo.Context, key string, source ExtractorSource) (bool, error) { + if source == ExtractorSourceQuery { + return true, nil + } + return false, errors.New("invalid source") + } + + }, + expectHandlerCalled: true, + }, } for _, tc := range testCases { @@ -267,7 +286,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) { { name: "ok, no error", whenConfig: KeyAuthConfig{ - Validator: func(c echo.Context, key string) (bool, error) { + Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) { return false, nil }, }, @@ -283,7 +302,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) { name: "ok, extractor source can not be split", whenConfig: KeyAuthConfig{ KeyLookup: "nope", - Validator: func(c echo.Context, key string) (bool, error) { + Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) { return false, nil }, }, @@ -293,7 +312,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) { name: "ok, no extractors", whenConfig: KeyAuthConfig{ KeyLookup: "nope:nope", - Validator: func(c echo.Context, key string) (bool, error) { + Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) { return false, nil }, }, diff --git a/middleware/logger.go b/middleware/logger.go index bd2d3d93..0e525e74 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -22,6 +22,8 @@ type LoggerConfig struct { // Tags to construct the logger format. // // - time_unix + // - time_unix_milli + // - time_unix_micro // - time_unix_nano // - time_rfc3339 // - time_rfc3339_nano @@ -119,6 +121,10 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { switch tag { case "time_unix": return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) + case "time_unix_milli": + return buf.WriteString(strconv.FormatInt(time.Now().UnixMilli(), 10)) + case "time_unix_micro": + return buf.WriteString(strconv.FormatInt(time.Now().UnixMicro(), 10)) case "time_unix_nano": return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) case "time_rfc3339": diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 2f1230dd..455520f9 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strconv" "strings" "testing" "time" @@ -172,6 +173,52 @@ func TestLoggerCustomTimestamp(t *testing.T) { assert.Error(t, err) } +func TestLoggerTemplateWithTimeUnixMilli(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `${time_unix_milli}`, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + unixMillis, err := strconv.ParseInt(buf.String(), 10, 64) + assert.NoError(t, err) + assert.WithinDuration(t, time.Unix(unixMillis/1000, 0), time.Now(), 3*time.Second) +} + +func TestLoggerTemplateWithTimeUnixMicro(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `${time_unix_micro}`, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + unixMicros, err := strconv.ParseInt(buf.String(), 10, 64) + assert.NoError(t, err) + assert.WithinDuration(t, time.Unix(unixMicros/1000000, 0), time.Now(), 3*time.Second) +} + func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) { e := echo.New() diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 266a575b..4ca10b84 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -32,7 +32,6 @@ func TestMethodOverride(t *testing.T) { func TestMethodOverride_formParam(t *testing.T) { e := echo.New() - m := MethodOverride() h := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -53,7 +52,6 @@ func TestMethodOverride_formParam(t *testing.T) { func TestMethodOverride_queryParam(t *testing.T) { e := echo.New() - m := MethodOverride() h := func(c echo.Context) error { return c.String(http.StatusOK, "test") } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 1d0dee91..c2ae7755 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -341,10 +341,9 @@ func TestProxyError(t *testing.T) { e := echo.New() e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() // Remote unreachable - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() req.URL.Path = "/api/users" e.ServeHTTP(rec, req) assert.Equal(t, "/api/users", req.URL.Path) diff --git a/middleware/recover.go b/middleware/recover.go index 70e98b26..7e46ccd7 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net/http" "runtime" "github.com/labstack/echo/v5" @@ -63,6 +64,9 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { defer func() { if r := recover(); r != nil { + if r == http.ErrAbortHandler { + panic(r) + } tmpErr, ok := r.(error) if !ok { tmpErr = fmt.Errorf("%v", r) diff --git a/middleware/recover_test.go b/middleware/recover_test.go index a65df541..f8d0db5e 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -51,6 +51,33 @@ func TestRecover_skipper(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain } +func TestRecoverErrAbortHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Recover()(func(c echo.Context) error { + panic(http.ErrAbortHandler) + }) + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`") + } else { + if err, ok := r.(error); ok { + assert.ErrorIs(t, err, http.ErrAbortHandler) + } else { + assert.Fail(t, "not of error type") + } + } + }() + + hErr := h(c) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.NotContains(t, hErr.Error(), "PANIC RECOVER") +} + func TestRecoverWithConfig(t *testing.T) { var testCases = []struct { name string diff --git a/router.go b/router.go index d4bb3077..42d3ec90 100644 --- a/router.go +++ b/router.go @@ -8,6 +8,15 @@ import ( ) // Router is interface for routing request contexts to registered routes. +// +// Contract between Echo/Context instance and the router: +// * all routes must be added through methods on echo.Echo instance. +// Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`). +// * Router must populate Context during Router.Route call with: +// * RoutableContext.SetPath +// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns) +// * RoutableContext.SetRouteInfo +// And optionally can set additional information to Context with RoutableContext.Set type Router interface { // Add registers Routable with the Router and returns registered RouteInfo Add(routable Routable) (RouteInfo, error) @@ -19,11 +28,6 @@ type Router interface { // Route searches Router for matching route and applies it to the given context. In case when no matching method // was not found (405) or no matching route exists for path (404), router will return its implementation of 405/404 // handler function. - // When implementing custom Router make sure to populate Context during Router.Route call with: - // * RoutableContext.SetPath - // * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns) - // * RoutableContext.SetRouteInfo - // And optionally can set additional information to Context with RoutableContext.Set Route(c RoutableContext) HandlerFunc } @@ -210,18 +214,22 @@ type routeMethod struct { } type routeMethods struct { - connect *routeMethod - delete *routeMethod - get *routeMethod - head *routeMethod - options *routeMethod - patch *routeMethod - post *routeMethod - propfind *routeMethod - put *routeMethod - trace *routeMethod - report *routeMethod - anyOther map[string]*routeMethod + connect *routeMethod + delete *routeMethod + get *routeMethod + head *routeMethod + options *routeMethod + patch *routeMethod + post *routeMethod + propfind *routeMethod + put *routeMethod + trace *routeMethod + report *routeMethod + anyOther map[string]*routeMethod + + // notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases + notFoundHandler *routeMethod + allowHeader string } @@ -249,6 +257,9 @@ func (m *routeMethods) set(method string, r *routeMethod) { m.trace = r case REPORT: m.report = r + case RouteNotFound: + m.notFoundHandler = r + return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed default: if m.anyOther == nil { m.anyOther = make(map[string]*routeMethod) @@ -353,6 +364,7 @@ func (m *routeMethods) isHandler() bool { m.trace != nil || m.report != nil || len(m.anyOther) != 0 + // RouteNotFound/404 is not considered as a handler } // Routes returns all registered routes @@ -611,7 +623,12 @@ func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMetho } currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < prefixLen { - // Split node + // Split node into two before we insert new node. + // This happens when we are inserting path that is submatch of any existing inserted paths. + // For example, we have node `/test` and now are about to insert `/te/*`. In that case + // 1. overlapping part is `/te` that is used as parent node + // 2. `st` is part from existing node that is not matching - it gets its own node (child to `/te`) + // 3. `/*` is the new part we are about to insert (child to `/te`) n := newNode( currentNode.kind, currentNode.prefix[lcpLen:], @@ -735,10 +752,8 @@ func (n *node) findStaticChild(l byte) *node { } func (n *node) findChildWithLabel(l byte) *node { - for _, c := range n.staticChildren { - if c.label == l { - return c - } + if c := n.findStaticChild(l); c != nil { + return c } if l == paramLabel { return n.paramChild @@ -751,11 +766,7 @@ func (n *node) findChildWithLabel(l byte) *node { func (n *node) setHandler(method string, r *routeMethod) { n.methods.set(method, r) - if r != nil && r.handler != nil { - n.isHandler = true - } else { - n.isHandler = n.methods.isHandler() - } + n.isHandler = n.methods.isHandler() } // Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to Context @@ -900,7 +911,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc { // No matching prefix, let's backtrack to the first possible alternative node of the decision path nk, ok := backtrackToNextNodeKind(staticKind) if !ok { - break // No other possibilities on the decision path + break // No other possibilities on the decision path, handler will be whatever context is reset to. } else if nk == paramKind { goto Param // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently @@ -916,15 +927,21 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc { search = search[lcpLen:] searchIndex = searchIndex + lcpLen - // Finish routing if no remaining search and we are on a node with handler and matching method type - if search == "" && currentNode.isHandler { - // check if current node has handler registered for http method we are looking for. we store currentNode as - // best matching in case we do no find no more routes matching this path+method - if previousBestMatchNode == nil { - previousBestMatchNode = currentNode - } - if rMethod := currentNode.methods.find(req.Method); rMethod != nil { - matchedRouteMethod = rMethod + // Finish routing if is no request path remaining to search + if search == "" { + // in case of node that is handler we have exact method type match or something for 405 to use + if currentNode.isHandler { + // check if current node has handler registered for http method we are looking for. we store currentNode as + // best matching in case we do no find no more routes matching this path+method + if previousBestMatchNode == nil { + previousBestMatchNode = currentNode + } + if h := currentNode.methods.find(req.Method); h != nil { + matchedRouteMethod = h + break + } + } else if currentNode.methods.notFoundHandler != nil { + matchedRouteMethod = currentNode.methods.notFoundHandler break } } @@ -944,7 +961,8 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc { i := 0 l := len(search) if currentNode.isLeaf { - // when param node does not have any children then param node should act similarly to any node - consider all remaining search as match + // when param node does not have any children (path param is last piece of route path) then param node should + // act similarly to any node - consider all remaining search as match i = l } else { for ; i < l && search[i] != '/'; i++ { @@ -969,13 +987,16 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc { searchIndex += +len(search) search = "" - // check if current node has handler registered for http method we are looking for. we store currentNode as - // best matching in case we do no find no more routes matching this path+method + if rMethod := currentNode.methods.find(req.Method); rMethod != nil { + matchedRouteMethod = rMethod + break + } + // we store currentNode as best matching in case we do not find more routes matching this path+method. Needed for 405 if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if rMethod := currentNode.methods.find(req.Method); rMethod != nil { - matchedRouteMethod = rMethod + if currentNode.methods.notFoundHandler != nil { + matchedRouteMethod = currentNode.methods.notFoundHandler break } } @@ -1017,7 +1038,13 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc { rPath = currentNode.originalPath rInfo = notFoundRouteInfo - if currentNode.isHandler { + if currentNode.methods.notFoundHandler != nil { + matchedRouteMethod = currentNode.methods.notFoundHandler + + rInfo = matchedRouteMethod.routeInfo + rPath = matchedRouteMethod.path + rHandler = matchedRouteMethod.handler + } else if currentNode.isHandler { rInfo = methodNotAllowedRouteInfo c.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader) diff --git a/router_test.go b/router_test.go index 3bf14ad0..83d3b1b2 100644 --- a/router_test.go +++ b/router_test.go @@ -1,7 +1,6 @@ package echo import ( - "fmt" "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" @@ -2243,6 +2242,64 @@ func TestRouter_Match_DifferentParamNamesForSamePlace(t *testing.T) { } } +// Issue #2164 - this test is meant to document path parameter behaviour when request url has empty value in place +// of the path parameter. As tests show the result is different depending on where parameter exists in the route path. +func TestDefaultRouter_PathParamsCanMatchEmptyValues(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute string + expectParam map[string]string + expectError error + }{ + { + name: "ok, route is matched with even empty param is in the middle and between slashes", + whenURL: "/a//b", + expectRoute: "/a/:id/b", + expectParam: map[string]string{"id": ""}, + }, + { + name: "ok, route is matched with even empty param is in the middle", + whenURL: "/a2/b", + expectRoute: "/a2:id/b", + expectParam: map[string]string{"id": ""}, + }, + { + name: "ok, route is NOT matched with even empty param is at the end", + whenURL: "/a3/", + expectRoute: "", + expectParam: map[string]string{}, + expectError: ErrNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + e.GET("/a/:id/b", handlerFunc) + e.GET("/a2:id/b", handlerFunc) + e.GET("/a3/:id", handlerFunc) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + c := e.NewContext(req, nil).(*DefaultContext) + handler := e.router.Route(c) + + err := handler(c) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.pathParams.Get(param, "")) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + // Issue #729 func TestRouterParamAlias(t *testing.T) { api := []testRoute{ @@ -3133,6 +3190,232 @@ func TestDefaultRouter_OptionsMethodHandler(t *testing.T) { assert.Equal(t, "not empty", body) } +func TestRouter_RouteWhenNotFoundRouteWithNodeSplitting(t *testing.T) { + e := New() + r := e.router + + hf := func(c Context) error { + return c.String(http.StatusOK, c.RouteInfo().Name()) + } + r.Add(Route{Method: http.MethodGet, Path: "/test*", Handler: hf, Name: "0"}) + r.Add(Route{Method: RouteNotFound, Path: "/test*", Handler: hf, Name: "1"}) + r.Add(Route{Method: RouteNotFound, Path: "/test", Handler: hf, Name: "2"}) + + // Tree before: + // 1 `/` + // 1.1 `*` (any) ID=1 + // 1.2 `test` (static) ID=2 + // 1.2.1 `*` (any) ID=0 + + // node with path `test` has routeNotFound handler from previous Add call. Now when we insert `/te/st*` into router tree + // This means that node `test` is split into `te` and `st` nodes and new node `/st*` is inserted. + // On that split `/test` routeNotFound handler must not be lost. + r.Add(Route{Method: http.MethodGet, Path: "/te/st*", Handler: hf, Name: "3"}) + // Tree after: + // 1 `/` + // 1.1 `*` (any) ID=1 + // 1.2 `te` (static) + // 1.2.1 `st` (static) ID=2 + // 1.2.1.1 `*` (any) ID=0 + // 1.2.2 `/st` (static) + // 1.2.2.1 `*` (any) ID=3 + + _, body := request(http.MethodPut, "/test", e) + + assert.Equal(t, "2", body) +} + +func TestRouter_RouteWhenNotFoundRouteAnyKind(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent /xx to not found handler /*", + whenURL: "/xx", + expectRoute: "/*", + expectID: 4, + expectParam: map[string]string{"*": "xx"}, + }, + { + name: "route not existent /a/xx to not found handler /a/*", + whenURL: "/a/xx", + expectRoute: "/a/*", + expectID: 5, + expectParam: map[string]string{"*": "xx"}, + }, + { + name: "route not existent /a/c/dxxx to not found handler /a/c/d*", + whenURL: "/a/c/dxxx", + expectRoute: "/a/c/d*", + expectID: 6, + expectParam: map[string]string{"*": "xxx"}, + }, + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.contextPathParamAllocSize = 1 + r := e.router + + r.Add(Route{Method: http.MethodGet, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"}) + r.Add(Route{Method: http.MethodGet, Path: "/a/c/df", Handler: handlerHelper("ID", 1), Name: "1"}) + r.Add(Route{Method: http.MethodGet, Path: "/a/b*", Handler: handlerHelper("ID", 2), Name: "2"}) + r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 3), Name: "3"}) + + r.Add(Route{Method: RouteNotFound, Path: "/a/c/d*", Handler: handlerHelper("ID", 6), Name: "6"}) + r.Add(Route{Method: RouteNotFound, Path: "/a/*", Handler: handlerHelper("ID", 5), Name: "5"}) + r.Add(Route{Method: RouteNotFound, Path: "/*", Handler: handlerHelper("ID", 4), Name: "4"}) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + c := e.NewContext(req, nil).(*DefaultContext) + + handler := r.Route(c) + handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.PathParam(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestRouter_RouteWhenNotFoundRouteParamKind(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent /xx to not found handler /:file", + whenURL: "/xx", + expectRoute: "/:file", + expectID: 4, + expectParam: map[string]string{"file": "xx"}, + }, + { + name: "route not existent /a/xx to not found handler /a/:file", + whenURL: "/a/xx", + expectRoute: "/a/:file", + expectID: 5, + expectParam: map[string]string{"file": "xx"}, + }, + { + name: "route not existent /a/c/dxxx to not found handler /a/c/d:file", + whenURL: "/a/c/dxxx", + expectRoute: "/a/c/d:file", + expectID: 6, + expectParam: map[string]string{"file": "xxx"}, + }, + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.contextPathParamAllocSize = 1 + r := e.router + + r.Add(Route{Method: http.MethodGet, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"}) + r.Add(Route{Method: http.MethodGet, Path: "/a/c/df", Handler: handlerHelper("ID", 1), Name: "1"}) + r.Add(Route{Method: http.MethodGet, Path: "/a/b*", Handler: handlerHelper("ID", 2), Name: "2"}) + r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 3), Name: "3"}) + + r.Add(Route{Method: RouteNotFound, Path: "/a/c/d:file", Handler: handlerHelper("ID", 6), Name: "6"}) + r.Add(Route{Method: RouteNotFound, Path: "/a/:file", Handler: handlerHelper("ID", 5), Name: "5"}) + r.Add(Route{Method: RouteNotFound, Path: "/:file", Handler: handlerHelper("ID", 4), Name: "4"}) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + c := e.NewContext(req, nil).(*DefaultContext) + + handler := r.Route(c) + handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.PathParam(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestRouter_RouteWhenNotFoundRouteStaticKind(t *testing.T) { + // note: static not found handler is quite silly thing to have but we still support it + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent / to not found handler /", + whenURL: "/", + expectRoute: "/", + expectID: 3, + expectParam: map[string]string{}, + }, + { + name: "route /a to /a", + whenURL: "/a", + expectRoute: "/a", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.contextPathParamAllocSize = 1 + r := e.router + + r.Add(Route{Method: http.MethodPut, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"}) + r.Add(Route{Method: http.MethodGet, Path: "/a", Handler: handlerHelper("ID", 1), Name: "1"}) + r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 2), Name: "2"}) + + r.Add(Route{Method: RouteNotFound, Path: "/", Handler: handlerHelper("ID", 3), Name: "3"}) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + c := e.NewContext(req, nil).(*DefaultContext) + + handler := r.Route(c) + handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.PathParam(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + type mySimpleRouter struct { route Route } @@ -3227,33 +3510,3 @@ func BenchmarkRouterGooglePlusAPIMisses(b *testing.B) { func BenchmarkRouterParamsAndAnyAPI(b *testing.B) { benchmarkRouterRoutes(b, paramAndAnyAPI, paramAndAnyAPIToFind) } - -func (n *node) printTree(pfx string, tail bool) { - p := prefix(tail, pfx, "└── ", "├── ") - fmt.Printf("%s%s, %p: type=%d, parent=%p, handler=%v, paramNames=%v\n", p, n.prefix, n, n.kind, n.parent, n.methods, n.paramsCount) - - p = prefix(tail, pfx, " ", "│ ") - - children := n.staticChildren - l := len(children) - - if n.paramChild != nil { - n.paramChild.printTree(p, n.anyChild == nil && l == 0) - } - if n.anyChild != nil { - n.anyChild.printTree(p, l == 0) - } - for i := 0; i < l-1; i++ { - children[i].printTree(p, false) - } - if l > 0 { - children[l-1].printTree(p, true) - } -} - -func prefix(tail bool, p, on, off string) string { - if tail { - return fmt.Sprintf("%s%s", p, on) - } - return fmt.Sprintf("%s%s", p, off) -} diff --git a/server_test.go b/server_test.go index d1751e92..fa6107dd 100644 --- a/server_test.go +++ b/server_test.go @@ -696,8 +696,7 @@ func TestStartConfig_WithHidePort(t *testing.T) { } assert.NoError(t, <-errCh) - portMsg := fmt.Sprintf("http(s) server started on") - contains := strings.Contains(buf.String(), portMsg) + contains := strings.Contains(buf.String(), "http(s) server started on") if tc.hidePort { assert.False(t, contains) } else {