diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index d1967212..af6a5a05 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -19,7 +19,7 @@ on: - '_fixture/**' - '.github/**' - 'codecov.yml' - workflow_dispatch: # to be able to run workflow manually + workflow_dispatch: jobs: test: @@ -29,72 +29,72 @@ jobs: # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases # except v5 starts from 1.17 until there is last four major releases after that - go: [1.17] + go: [1.17, 1.18] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - name: Checkout Code - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.ref }} + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/lint/golint + run: | + go install golang.org/x/lint/golint@latest + go install honnef.co/go/tools/cmd/staticcheck@latest - name: Run Tests run: | golint -set_exit_status ./... + staticcheck ./... go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - - name: Upload coverage to Codecov - if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v2 + if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v1 with: + token: fail_ci_if_error: false - benchmark: needs: test strategy: matrix: os: [ubuntu-latest] - go: [1.17] + go: [1.18] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v2 - with: - go-version: ${{ matrix.go }} - - name: Checkout Code (Previous) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: new + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/perf/cmd/benchstat + run: go install golang.org/x/perf/cmd/benchstat@latest - name: Run Benchmark (Previous) run: | cd previous go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - - name: Run Benchmark (New) run: | cd new go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - - name: Run Benchstat run: | benchstat previous/benchmark.txt new/benchmark.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 372ed13c..ba75d71f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,50 @@ # Changelog +## v4.7.2 - 2022-03-16 + +**Fixes** + +* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131) +* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136) +* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126) + +**Enhancements** + +* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134) + + +## v4.7.1 - 2022-03-13 + +**Fixes** + +* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123) + +**Enhancements** + +* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116) + + +## v4.7.0 - 2022-03-01 + +**Enhancements** + +* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060) +* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072) +* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027) +* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064) + +**Fixes** + +* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007) +* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102) + +**General** + +* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103) +* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078) +* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049) +* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README + ## v4.6.3 - 2022-01-10 **Fixes** diff --git a/Makefile b/Makefile index 10f9c8f5..8149aeba 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,8 @@ tag: check: lint vet race ## Check project init: - @go get -u golang.org/x/lint/golint + @go install golang.org/x/lint/golint@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest lint: ## Lint the files @golint -set_exit_status ${PKG_LIST} @@ -29,6 +30,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.16" -test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16 +goversion ?= "1.17" +test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/binder.go b/binder.go index 402b80bc..e022a7e8 100644 --- a/binder.go +++ b/binder.go @@ -1,6 +1,8 @@ package echo import ( + "encoding" + "encoding/json" "fmt" "net/http" "strconv" @@ -52,8 +54,11 @@ import ( * time * duration * BindUnmarshaler() interface + * TextUnmarshaler() interface + * JSONUnmarshaler() interface * UnixTime() - converts unix time (integer) to time.Time - * UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time + * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time + * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error` */ @@ -204,7 +209,7 @@ func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []st return b.customFunc(sourceParam, customFunc, false) } -// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist. +// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist. func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { return b.customFunc(sourceParam, customFunc, true) } @@ -241,7 +246,7 @@ func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { return b } -// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist +// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -270,7 +275,7 @@ func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { return b } -// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist +// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -302,7 +307,7 @@ func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) return b } -// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface. +// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -321,13 +326,85 @@ func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshal return b } +// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface +func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface +func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + +// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + // BindWithDelimiter binds parameter to destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, false) } -// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function. +// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, true) @@ -376,7 +453,7 @@ func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, false) } -// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist +// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, true) } @@ -386,7 +463,7 @@ func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, false) } -// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist +// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, true) } @@ -396,7 +473,7 @@ func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, false) } -// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist +// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, true) } @@ -406,7 +483,7 @@ func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, false) } -// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist +// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, true) } @@ -416,7 +493,7 @@ func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, false) } -// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist +// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, true) } @@ -544,7 +621,7 @@ func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist +// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -554,7 +631,7 @@ func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist +// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -564,7 +641,7 @@ func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist +// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -574,7 +651,7 @@ func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist +// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -584,7 +661,7 @@ func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist +// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -594,7 +671,7 @@ func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, false) } -// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist +// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, true) } @@ -604,7 +681,7 @@ func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, false) } -// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist +// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, true) } @@ -614,7 +691,7 @@ func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, false) } -// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist +// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, true) } @@ -624,7 +701,7 @@ func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist +// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -634,7 +711,7 @@ func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist +// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -644,7 +721,7 @@ func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, false) } -// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist +// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, true) } @@ -772,7 +849,7 @@ func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist +// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -782,7 +859,7 @@ func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist +// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -792,7 +869,7 @@ func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist +// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -802,7 +879,7 @@ func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist +// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -812,7 +889,7 @@ func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist +// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -822,7 +899,7 @@ func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, false) } -// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist +// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, true) } @@ -887,7 +964,7 @@ func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, false) } -// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist +// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, true) } @@ -897,7 +974,7 @@ func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, false) } -// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist +// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, true) } @@ -907,7 +984,7 @@ func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, false) } -// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist +// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, true) } @@ -992,7 +1069,7 @@ func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist +// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1002,7 +1079,7 @@ func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist +// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1012,7 +1089,7 @@ func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) * return b.time(sourceParam, dest, layout, false) } -// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist +// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { return b.time(sourceParam, dest, layout, true) } @@ -1043,7 +1120,7 @@ func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string return b.times(sourceParam, dest, layout, false) } -// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist +// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { return b.times(sourceParam, dest, layout, true) } @@ -1084,7 +1161,7 @@ func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBi return b.duration(sourceParam, dest, false) } -// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist +// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { return b.duration(sourceParam, dest, true) } @@ -1115,7 +1192,7 @@ func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *Valu return b.durationsValue(sourceParam, dest, false) } -// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist +// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { return b.durationsValue(sourceParam, dest, true) } @@ -1161,10 +1238,10 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, false) + return b.unixTime(sourceParam, dest, false, time.Second) } -// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time). Returns error when value does not exist. // // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 @@ -1172,10 +1249,31 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, false) + return b.unixTime(sourceParam, dest, true, time.Second) } -// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision). +// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision). +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Millisecond) +} + +// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding +// to the given Unix time in millisecond precision). Returns error when value does not exist. +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Millisecond) +} + +// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 @@ -1185,10 +1283,10 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, true) + return b.unixTime(sourceParam, dest, false, time.Nanosecond) } -// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding // to the given Unix time value in nano second precision). Returns error when value does not exist. // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 @@ -1199,10 +1297,10 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, true) + return b.unixTime(sourceParam, dest, true, time.Nanosecond) } -func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder { +func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -1221,10 +1319,13 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi return b } - if isNano { - *dest = time.Unix(0, n) - } else { + switch precision { + case time.Second: *dest = time.Unix(n, 0) + case time.Millisecond: + *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows + case time.Nanosecond: + *dest = time.Unix(0, n) } return b } diff --git a/binder_test.go b/binder_test.go index f57da32d..3c17057c 100644 --- a/binder_test.go +++ b/binder_test.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "io" + "math/big" "net/http" "net/http/httptest" "strconv" @@ -55,7 +56,7 @@ func TestBindingError_Error(t *testing.T) { func TestBindingError_ErrorJSON(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) - resp, err := json.Marshal(err) + resp, _ := json.Marshal(err) assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } @@ -2188,6 +2189,188 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { } } +func TestValueBinder_JSONUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustJSONUnmarshaler("param", &dest).BindError() + } else { + err = b.JSONUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TextUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustTextUnmarshaler("param", &dest).BindError() + } else { + err = b.TextUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { name string @@ -2530,6 +2713,97 @@ func TestValueBinder_UnixTime(t *testing.T) { } } +func TestValueBinder_UnixTimeMilli(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value, unix time in milliseconds", + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeMilli("param", &dest).BindError() + } else { + err = b.UnixTimeMilli("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 diff --git a/context.go b/context.go index a397fba7..5f4a2010 100644 --- a/context.go +++ b/context.go @@ -55,6 +55,16 @@ type Context interface { // PathParam returns path parameter by name. PathParam(name string) string + // PathParamDefault returns the path parameter or default value for the provided name. + // + // Notes for DefaultRouter implementation: + // Path parameter could be empty for cases like that: + // * route `/release-:version/bin` and request URL is `/release-/bin` + // * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` + // but not when path parameter is last part of route path + // * route `/download/file.:ext` will not match request `/download/file.` + PathParamDefault(name string, defaultValue string) string + // PathParams returns path parameter values. PathParams() PathParams @@ -176,6 +186,9 @@ type Context interface { Redirect(code int, url string) error // Echo returns the `Echo` instance. + // + // WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them + // anywhere in your code after Echo server has started. Echo() *Echo } @@ -379,7 +392,10 @@ func (c *DefaultContext) PathParam(name string) string { return c.pathParams.Get(name, "") } -// PathParamDefault does not exist as expecting empty path param makes no sense +// PathParamDefault returns the path parameter or default value for the provided name. +func (c *DefaultContext) PathParamDefault(name, defaultValue string) string { + return c.pathParams.Get(name, defaultValue) +} // PathParams returns path parameter values. func (c *DefaultContext) PathParams() PathParams { @@ -406,7 +422,8 @@ func (c *DefaultContext) QueryParam(name string) string { } // QueryParamDefault returns the query param or default value for the provided name. -// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string +// Note: QueryParamDefault does not distinguish if query had no value by that name or value was empty string +// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamDefault("search", "1")` func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string { value := c.QueryParam(name) if value == "" { diff --git a/context_test.go b/context_test.go index dca680f9..0df10539 100644 --- a/context_test.go +++ b/context_test.go @@ -448,22 +448,142 @@ func TestContextCookie(t *testing.T) { assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } -func TestContextPathParam(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil).(*DefaultContext) - - params := &PathParams{ - {Name: "uid", Value: "101"}, - {Name: "fid", Value: "501"}, +func TestContext_PathParams(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + expect PathParams + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + expect: PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + }, + { + name: "params is empty", + given: &PathParams{}, + expect: PathParams{}, + }, } - // ParamNames - c.pathParams = params - assert.EqualValues(t, *params, c.PathParams()) - // Param - assert.Equal(t, "501", c.PathParam("fid")) - assert.Equal(t, "", c.PathParam("undefined")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) + + assert.EqualValues(t, tc.expect, c.PathParams()) + }) + } +} + +func TestContext_PathParam(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + whenParamName string + expect string + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "multiple same param values exists - return first", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "uid", Value: "202"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "param does not exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) + + assert.EqualValues(t, tc.expect, c.PathParam(tc.whenParamName)) + }) + } +} + +func TestContext_PathParamDefault(t *testing.T) { + var testCases = []struct { + name string + given *PathParams + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "param exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "101", + }, + { + name: "param exists and is empty", + given: &PathParams{ + {Name: "uid", Value: ""}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "", // <-- this is different from QueryParamDefault behaviour + }, + { + name: "param does not exists", + given: &PathParams{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + whenDefaultValue: "999", + expect: "999", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.(RoutableContext).SetRawPathParams(tc.given) + + assert.EqualValues(t, tc.expect, c.PathParamDefault(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextGetAndSetParam(t *testing.T) { @@ -568,27 +688,129 @@ func TestContextFormValue(t *testing.T) { assert.Error(t, err) } -func TestContextQueryParam(t *testing.T) { - q := make(url.Values) - q.Set("name", "Jon Snow") - q.Set("email", "jon@labstack.com") - req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) - e := New() - c := e.NewContext(req, nil) +func TestContext_QueryParams(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect url.Values + }{ + { + name: "multiple values in url", + givenURL: "/?test=1&test=2&email=jon%40labstack.com", + expect: url.Values{ + "test": []string{"1", "2"}, + "email": []string{"jon@labstack.com"}, + }, + }, + { + name: "single value in url", + givenURL: "/?nope=1", + expect: url.Values{ + "nope": []string{"1"}, + }, + }, + { + name: "no query params in url", + givenURL: "/?", + expect: url.Values{}, + }, + } - // QueryParam - assert.Equal(t, "Jon Snow", c.QueryParam("name")) - assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) - // QueryParamDefault - assert.Equal(t, "Jon Snow", c.QueryParamDefault("name", "nope")) - assert.Equal(t, "default", c.QueryParamDefault("missing", "default")) + assert.Equal(t, tc.expect, c.QueryParams()) + }) + } +} - // QueryParams - assert.Equal(t, url.Values{ - "name": []string{"Jon Snow"}, - "email": []string{"jon@labstack.com"}, - }, c.QueryParams()) +func TestContext_QueryParam(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + expect: "1", + }, + { + name: "multiple values exists in url", + givenURL: "/?test=9&test=8", + whenParamName: "test", + expect: "9", // <-- first value in returned + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + expect: "", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) + }) + } +} + +func TestContext_QueryParamDefault(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "1", + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParamDefault(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextFormFile(t *testing.T) { diff --git a/echo.go b/echo.go index 29fa2281..c43d26e2 100644 --- a/echo.go +++ b/echo.go @@ -49,12 +49,15 @@ import ( "os" "os/signal" "path/filepath" + "runtime" "strings" "sync" ) // Echo is the top-level framework instance. -// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. +// +// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. This is very likely +// to happen when you access Echo instances through Context.Echo() method. type Echo struct { // premiddleware are middlewares that are run for every request before routing is done premiddleware []MiddlewareFunc @@ -66,8 +69,8 @@ type Echo struct { routerCreator func(e *Echo) Router contextPool sync.Pool - // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info for context - // creation time so we can allocate path parameter values slice. + // contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info at context + // creation moment so we can allocate path parameter values slice with correct size. contextPathParamAllocSize int // NewContextFunc allows using custom context implementations, instead of default *echo.context @@ -150,6 +153,8 @@ const ( PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" + // RouteNotFound is special method type for routes handling "route not found" (404) cases + RouteNotFound = "echo_route_not_found" ) // Headers @@ -181,12 +186,14 @@ const ( HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXRealIP = "X-Real-IP" - HeaderXRequestID = "X-Request-ID" - HeaderXCorrelationID = "X-Correlation-ID" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" @@ -403,6 +410,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo return e.Add(http.MethodTrace, path, h, m...) } +// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) +// for current request URL. +// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with +// wildcard/match-any character (`/*`, `/download/*` etc). +// +// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return e.Add(RouteNotFound, path, h, m...) +} + // Any registers a new route for all supported HTTP methods and path with matching handler // in the router with optional route-level middleware. Panics on error. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { @@ -707,20 +724,26 @@ func newDefaultFS() *defaultFS { dir, _ := os.Getwd() return &defaultFS{ prefix: dir, - fs: os.DirFS(dir), + fs: nil, } } func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) + } return fs.fs.Open(name) } func subFS(currentFs fs.FS, root string) (fs.FS, error) { root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to - // allow cases when root is given as `../somepath` which is not valid for fs.FS - root = filepath.Join(dFS.prefix, root) + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if isRelativePath(root) { + root = filepath.Join(dFS.prefix, root) + } return &defaultFS{ prefix: root, fs: os.DirFS(root), @@ -729,6 +752,21 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { return fs.Sub(currentFs, root) } +func isRelativePath(path string) bool { + if path == "" { + return true + } + if path[0] == '/' { + return false + } + if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { + // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names + // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats + return false + } + return true +} + // MustSubFS creates sub FS from current filesystem or panic on failure. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // diff --git a/echo_test.go b/echo_test.go index ba21014e..b75bd253 100644 --- a/echo_test.go +++ b/echo_test.go @@ -336,16 +336,54 @@ func TestEchoStaticRedirectIndex(t *testing.T) { } func TestEchoFile(t *testing.T) { - e := New() - ri := e.File("/walle", "_fixture/images/walle.png") - assert.Equal(t, http.MethodGet, ri.Method()) - assert.Equal(t, "/walle", ri.Path()) - assert.Equal(t, "GET:/walle", ri.Name()) - assert.Nil(t, ri.Params()) + var testCases = []struct { + name string + givenPath string + givenFile string + whenPath string + expectCode int + expectStartsWith string + }{ + { + name: "ok", + givenPath: "/walle", + givenFile: "_fixture/images/walle.png", + whenPath: "/walle", + expectCode: http.StatusOK, + expectStartsWith: string([]byte{0x89, 0x50, 0x4e}), + }, + { + name: "ok with relative path", + givenPath: "/", + givenFile: "./go.mod", + whenPath: "/", + expectCode: http.StatusOK, + expectStartsWith: "module github.com/labstack/echo/v", + }, + { + name: "nok file does not exist", + givenPath: "/", + givenFile: "./this-file-does-not-exist", + whenPath: "/", + expectCode: http.StatusNotFound, + expectStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } - c, b := request(http.MethodGet, "/walle", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() // we are using echo.defaultFS instance + e.File(tc.givenPath, tc.givenFile) + + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) + + if len(b) > len(tc.expectStartsWith) { + b = b[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, b) + }) + } } func TestEchoMiddleware(t *testing.T) { @@ -880,6 +918,70 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "023", buf.String()) } +func TestEcho_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /a/c/xx", + whenURL: "/a/c/xx", + expectRoute: "GET /a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /a/:file", + whenURL: "/a/echo.exe", + expectRoute: "GET /a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /*", + whenURL: "/b/echo.exe", + expectRoute: "GET /*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "GET /a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + e.GET("/", okHandler) + e.GET("/a/c/df", okHandler) + e.GET("/a/b*", okHandler) + e.PUT("/*", okHandler) + + e.RouteNotFound("/a/c/xx", notFoundHandler) // static + e.RouteNotFound("/a/:file", notFoundHandler) // param + e.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestEchoNotFound(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/files", nil) diff --git a/group.go b/group.go index 4f04a73a..b9df5af9 100644 --- a/group.go +++ b/group.go @@ -157,6 +157,13 @@ func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo return g.Add(http.MethodGet, path, handler, middleware...) } +// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. +// +// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return g.Add(RouteNotFound, path, h, m...) +} + // Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := g.AddRoute(Route{ diff --git a/group_test.go b/group_test.go index 3914c0bd..bd215726 100644 --- a/group_test.go +++ b/group_test.go @@ -303,6 +303,71 @@ func TestGroup_TRACE(t *testing.T) { assert.Equal(t, `OK`, body) } +func TestGroup_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /group/a/c/xx", + whenURL: "/group/a/c/xx", + expectRoute: "GET /group/a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /group/a/:file", + whenURL: "/group/a/echo.exe", + expectRoute: "GET /group/a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /group/*", + whenURL: "/group/b/echo.exe", + expectRoute: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /group/a/c/df to /group/a/c/df", + whenURL: "/group/a/c/df", + expectRoute: "GET /group/a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/group") + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + g.GET("/", okHandler) + g.GET("/a/c/df", okHandler) + g.GET("/a/b*", okHandler) + g.PUT("/*", okHandler) + + g.RouteNotFound("/a/c/xx", notFoundHandler) // static + g.RouteNotFound("/a/:file", notFoundHandler) // param + g.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestGroup_Any(t *testing.T) { e := New() diff --git a/ip.go b/ip.go index 39cb421f..46d464cf 100644 --- a/ip.go +++ b/ip.go @@ -6,6 +6,130 @@ import ( "strings" ) +/** +By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 ) +Source: https://echo.labstack.com/guide/ip-address/ + +IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more. +Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that. + +However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application. +In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally. +Otherwise, you might give someone a chance of deceiving you. **A security risk!** + +To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure. +In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. +This guides show you why and how. + +> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. + +Let's start from two questions to know the right direction: + +1. Do you put any HTTP (L7) proxy in front of the application? + - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway). +2. If yes, what HTTP header do your proxies use to pass client IP to the application? + +## Case 1. With no proxy + +If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer. +Any HTTP header is untrustable because the clients have full control what headers to be set. + +In this case, use `echo.ExtractIPDirect()`. + +```go +e.IPExtractor = echo.ExtractIPDirect() +``` + +## Case 2. With proxies using `X-Forwarded-For` header + +[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header +to relay clients' IP addresses. +At each hop on the proxies, they append the request IP address at the end of the header. + +Following example diagram illustrates this behavior. + +```text +┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │ +│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │ +└──────────┘ └──────────┘ └──────────┘ └──────────┘ + +Case 1. +XFF: "" "a" "a, b" + ~~~~~~ +Case 2. +XFF: "x" "x, a" "x, a, b" + ~~~~~~~~~ + ↑ What your app will see +``` + +In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". +In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. + +In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader() +``` + +By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +E.g.: + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader( + TrustLinkLocal(false), + TrustIPRanges(lbIPRange), +) +``` + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +## Case 3. With proxies using `X-Real-IP` header + +`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF. + +If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromRealIPHeader() +``` + +Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**. +> Otherwise there is a chance of fraud, as it is what clients can control. + +## About default behavior + +In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer. + +As you might already notice, after reading this article, this is not good. +Sole reason this is default is just backward compatibility. + +## Private IP ranges + +See: https://en.wikipedia.org/wiki/Private_network + +Private IPv4 address ranges (RFC 1918): +* 10.0.0.0 – 10.255.255.255 (24-bit block) +* 172.16.0.0 – 172.31.255.255 (20-bit block) +* 192.168.0.0 – 192.168.255.255 (16-bit block) + +Private IPv6 address ranges: +* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + +*/ + type ipChecker struct { trustLoopback bool trustLinkLocal bool @@ -52,6 +176,7 @@ func newIPChecker(configs []TrustOption) *ipChecker { return checker } +// Go1.16+ added `ip.IsPrivate()` but until that use this implementation func isPrivateIPRange(ip net.IP) bool { if ip4 := ip.To4(); ip4 != nil { return ip4[0] == 10 || @@ -87,10 +212,12 @@ type IPExtractor func(*http.Request) string // ExtractIPDirect extracts IP address using actual IP address. // Use this if your server faces to internet directory (i.e.: uses no proxy). func ExtractIPDirect() IPExtractor { - return func(req *http.Request) string { - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra - } + return extractIP +} + +func extractIP(req *http.Request) string { + ra, _, _ := net.SplitHostPort(req.RemoteAddr) + return ra } // ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. @@ -98,14 +225,13 @@ func ExtractIPDirect() IPExtractor { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { - if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) { + if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } } - return directIP + return extractIP(req) } } @@ -115,7 +241,7 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) + directIP := extractIP(req) xffs := req.Header[HeaderXForwardedFor] if len(xffs) == 0 { return directIP diff --git a/ip_test.go b/ip_test.go index 5acc1179..755900d3 100644 --- a/ip_test.go +++ b/ip_test.go @@ -1,235 +1,606 @@ package echo import ( + "github.com/stretchr/testify/assert" "net" "net/http" - "strings" "testing" - - testify "github.com/stretchr/testify/assert" ) -const ( - // For RemoteAddr - ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8 - sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080" - ipForRemoteAddrExternal = "203.0.113.1" - sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080" - // For x-real-ip - ipForRealIP = "203.0.113.10" - // For XFF - ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16 - ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16 - ipForXFF3External = "2001:db8::103" - ipForXFF4Private = "fc00::104" // From fc00::/7 - ipForXFF5External = "198.51.100.105" - ipForXFF6External = "192.0.2.106" - ipForXFFBroken = "this.is.broken.lol" - // keys for test cases - ipTestReqKeyNoHeader = "no header" - ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external" - ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal" - ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external" - ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal" - ipTestReqKeyXFFExternal = "xff; remote addr external" - ipTestReqKeyXFFInternal = "xff; remote addr internal" - ipTestReqKeyBrokenXFF = "broken xff" -) - -var ( - sampleXFF = strings.Join([]string{ - ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal, - }, ", ") - - requests = map[string]*http.Request{ - ipTestReqKeyNoHeader: &http.Request{ - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyRealIPExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - }, - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyRealIPInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, - ipTestReqKeyRealIPAndXFFExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyRealIPAndXFFInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, - ipTestReqKeyXFFExternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrExternal, - }, - ipTestReqKeyXFFInternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, - ipTestReqKeyBrokenXFF: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal}, - }, - RemoteAddr: sampleRemoteAddrLoopback, - }, +func mustParseCIDR(s string) *net.IPNet { + _, IPNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) } -) + return IPNet +} -func TestExtractIP(t *testing.T) { - _, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0") - _, ipv6AllRange, _ := net.ParseCIDR("::/0") - _, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48") - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24") - - tests := map[string]*struct { - extractor IPExtractor - expectedIPs map[string]string +func TestIPChecker_TrustOption(t *testing.T) { + var testCases = []struct { + name string + givenOptions []TrustOption + whenIP string + expect bool }{ - "ExtractIPDirect": { - ExtractIPDirect(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustLoopback(false), + TrustLinkLocal(false), + TrustPrivateNet(false), + // this is private IPv6 ip + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + TrustIPRange(mustParseCIDR("2001:db8::103/48")), }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, }, - "ExtractIPFromRealIPHeader(default)": { - ExtractIPFromRealIPHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust only direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(default)": { - ExtractIPFromXFFHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust only direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForXFF3External, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust everything)": { - // This is similar to legacy behavior, but ignores x-real-ip header. - ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External, - ipTestReqKeyXFFExternal: ipForXFF6External, - ipTestReqKeyXFFInternal: ipForXFF6External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust ipForXFF3External)": { - // This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't - ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF5External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustIPRange(mustParseCIDR("2001:db8::103/48")), }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, }, } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - assert := testify.New(t) - for key, req := range requests { - actual := test.extractor(req) - expected := test.expectedIPs[key] - assert.Equal(expected, actual, "Request: %s", key) - } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker(tc.givenOptions) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustIPRange(t *testing.T) { + var testCases = []struct { + name string + givenRange string + whenIP string + expect bool + }{ + { + name: "ip is within trust range, IPV6 network range", + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff", + expect: false, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.9.0", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.7.255", + expect: false, + }, + { + name: "public ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "internal ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "127.0.10.1", + expect: true, + }, + { + name: "public ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "2a00:1450:4026:805::200e", + expect: true, + }, + { + name: "internal ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "0:0:0:0:0:0:0:1", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cidr := mustParseCIDR(tc.givenRange) + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustIPRange(cidr), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustPrivateNet(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "do not trust public IPv4 address", + whenIP: "8.8.8.8", + expect: false, + }, + { + name: "do not trust public IPv6 address", + whenIP: "2a00:1450:4026:805::200e", + expect: false, + }, + + { // Class A: 10.0.0.0 — 10.255.255.255 + name: "do not trust IPv4 just outside of class A (lower bounds)", + whenIP: "9.255.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class A (upper bounds)", + whenIP: "11.0.0.0", + expect: false, + }, + { + name: "trust IPv4 of class A (lower bounds)", + whenIP: "10.0.0.0", + expect: true, + }, + { + name: "trust IPv4 of class A (upper bounds)", + whenIP: "10.255.255.255", + expect: true, + }, + + { // Class B: 172.16.0.0 — 172.31.255.255 + name: "do not trust IPv4 just outside of class B (lower bounds)", + whenIP: "172.15.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class B (upper bounds)", + whenIP: "172.32.0.0", + expect: false, + }, + { + name: "trust IPv4 of class B (lower bounds)", + whenIP: "172.16.0.0", + expect: true, + }, + { + name: "trust IPv4 of class B (upper bounds)", + whenIP: "172.31.255.255", + expect: true, + }, + + { // Class C: 192.168.0.0 — 192.168.255.255 + name: "do not trust IPv4 just outside of class C (lower bounds)", + whenIP: "192.167.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class C (upper bounds)", + whenIP: "192.169.0.0", + expect: false, + }, + { + name: "trust IPv4 of class C (lower bounds)", + whenIP: "192.168.0.0", + expect: true, + }, + { + name: "trust IPv4 of class C (upper bounds)", + whenIP: "192.168.255.255", + expect: true, + }, + + { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + // splits the address block in two equally sized halves, fc00::/8 and fd00::/8. + // https://en.wikipedia.org/wiki/Unique_local_address + name: "trust IPv6 private address", + whenIP: "fdfc:3514:2cb3:4bd5::", + expect: true, + }, + { + name: "do not trust IPv6 just out of /fd (upper bounds)", + whenIP: "/fe00:0000:0000:0000:0000", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + + TrustPrivateNet(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLinkLocal(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust link local IPv4 address (lower bounds)", + whenIP: "169.254.0.0", + expect: true, + }, + { + name: "trust link local IPv4 address (upper bounds)", + whenIP: "169.254.255.255", + expect: true, + }, + { + name: "do not trust link local IPv4 address (outside of lower bounds)", + whenIP: "169.253.255.255", + expect: false, + }, + { + name: "do not trust link local IPv4 address (outside of upper bounds)", + whenIP: "169.255.0.0", + expect: false, + }, + { + name: "trust link local IPv6 address ", + whenIP: "fe80::1", + expect: true, + }, + { + name: "do not trust link local IPv6 address ", + whenIP: "fec0::1", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLinkLocal(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLoopback(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust IPv4 as localhost", + whenIP: "127.0.0.1", + expect: true, + }, + { + name: "trust IPv6 as localhost", + whenIP: "::1", + expect: true, + }, + { + name: "do not trust public ip as localhost", + whenIP: "8.8.8.8", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLoopback(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestExtractIPDirect(t *testing.T) { + var testCases = []struct { + name string + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"127.0.0.1"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPDirect()(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromRealIPHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromXFFHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request has INVALID external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.3", + }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed) + // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs) + // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office) + // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"}, + }, + RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP + }, + expectIP: "203.0.100.100", // this is first trusted IP in XFF chain + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) }) } } diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 82e2fbf7..3071eedb 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/base64" "errors" - "fmt" + "net/http" "strconv" "strings" @@ -72,9 +72,11 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { continue } + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) if errDecode != nil { - lastError = fmt.Errorf("invalid basic auth value: %w", errDecode) + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) continue } idx := bytes.IndexByte(b, ':') diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 9580dff0..3d69ae84 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -56,7 +56,7 @@ func TestBasicAuth(t *testing.T) { name: "nok, not base64 Authorization header", givenConfig: defaultConfig, whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, - expectErr: "invalid basic auth value: illegal base64 data at input byte 3", + expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3", }, { name: "nok, missing Authorization header", diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index f367f938..6e7778ea 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -43,6 +43,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) @@ -55,6 +56,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) { // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() diff --git a/middleware/csrf.go b/middleware/csrf.go index acab8790..895a9c63 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -144,7 +144,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastTokenErr error outer: for _, extractor := range extractors { - clientTokens, err := extractor(c) + clientTokens, _, err := extractor(c) if err != nil { lastExtractorErr = err continue diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index c35ed6fa..5651a493 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -79,9 +79,6 @@ func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { func TestDecompressWithConfig_DefaultConfig(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing @@ -91,10 +88,10 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) { // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) err := h(c) assert.NoError(t, err) diff --git a/middleware/extractor.go b/middleware/extractor.go index 94134f7e..4ad676f6 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -13,6 +13,24 @@ const ( extractorLimit = 20 ) +// ExtractorSource is type to indicate source for extracted value +type ExtractorSource string + +const ( + // ExtractorSourceHeader means value was extracted from request header + ExtractorSourceHeader ExtractorSource = "header" + // ExtractorSourceQuery means value was extracted from request query parameters + ExtractorSourceQuery ExtractorSource = "query" + // ExtractorSourcePathParam means value was extracted from route path parameters + ExtractorSourcePathParam ExtractorSource = "param" + // ExtractorSourceCookie means value was extracted from request cookies + ExtractorSourceCookie ExtractorSource = "cookie" + // ExtractorSourceForm means value was extracted from request form values + ExtractorSourceForm ExtractorSource = "form" + // ExtractorSourceCustom means value was extracted by custom extractor + ExtractorSourceCustom ExtractorSource = "custom" +) + // ValueExtractorError is error type when middleware extractor is unable to extract value from lookups type ValueExtractorError struct { message string @@ -31,7 +49,7 @@ var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing valu var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. -type ValuesExtractor func(c echo.Context) ([]string, error) +type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error) func createExtractors(lookups string) ([]ValuesExtractor, error) { if lookups == "" { @@ -75,10 +93,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { prefixLen := len(valuePrefix) // standard library parses http.Request header keys in canonical form but we may provide something else so fix this header = textproto.CanonicalMIMEHeaderKey(header) - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { values := c.Request().Header.Values(header) if len(values) == 0 { - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } result := make([]string, 0) @@ -100,30 +118,30 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { if len(result) == 0 { if prefixLen > 0 { - return nil, errHeaderExtractorValueInvalid + return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid } - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } - return result, nil + return result, ExtractorSourceHeader, nil } } // valuesFromQuery returns a function that extracts values from the query string. func valuesFromQuery(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { result := c.QueryParams()[param] if len(result) == 0 { - return nil, errQueryExtractorValueMissing + return nil, ExtractorSourceQuery, errQueryExtractorValueMissing } else if len(result) > extractorLimit-1 { result = result[:extractorLimit] } - return result, nil + return result, ExtractorSourceQuery, nil } } // valuesFromParam returns a function that extracts values from the url param string. func valuesFromParam(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { result := make([]string, 0) for i, p := range c.PathParams() { if param == p.Name { @@ -134,18 +152,18 @@ func valuesFromParam(param string) ValuesExtractor { } } if len(result) == 0 { - return nil, errParamExtractorValueMissing + return nil, ExtractorSourcePathParam, errParamExtractorValueMissing } - return result, nil + return result, ExtractorSourcePathParam, nil } } // valuesFromCookie returns a function that extracts values from the named cookie. func valuesFromCookie(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { + return func(c echo.Context) ([]string, ExtractorSource, error) { cookies := c.Cookies() if len(cookies) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } result := make([]string, 0) @@ -158,26 +176,26 @@ func valuesFromCookie(name string) ValuesExtractor { } } if len(result) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } - return result, nil + return result, ExtractorSourceCookie, nil } } // valuesFromForm returns a function that extracts values from the form field. func valuesFromForm(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { - if parseErr := c.Request().ParseForm(); parseErr != nil { - return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr) + return func(c echo.Context) ([]string, ExtractorSource, error) { + if c.Request().Form == nil { + _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does } values := c.Request().Form[name] if len(values) == 0 { - return nil, errFormExtractorValueMissing + return nil, ExtractorSourceForm, errFormExtractorValueMissing } if len(values) > extractorLimit-1 { values = values[:extractorLimit] } result := append([]string{}, values...) - return result, nil + return result, ExtractorSourceForm, nil } } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 439c4d8f..afa776ec 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -1,9 +1,11 @@ package middleware import ( + "bytes" "fmt" "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" + "mime/multipart" "net/http" "net/http/httptest" "net/url" @@ -18,6 +20,7 @@ func TestCreateExtractors(t *testing.T) { givenPathParams echo.PathParams whenLoopups string expectValues []string + expectSource ExtractorSource expectCreateError string expectError string }{ @@ -30,6 +33,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "header:Authorization:Bearer ", expectValues: []string{"token"}, + expectSource: ExtractorSourceHeader, }, { name: "ok, form", @@ -43,6 +47,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "form:name", expectValues: []string{"Jon Snow"}, + expectSource: ExtractorSourceForm, }, { name: "ok, cookie", @@ -53,6 +58,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "cookie:_csrf", expectValues: []string{"token"}, + expectSource: ExtractorSourceCookie, }, { name: "ok, param", @@ -61,6 +67,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "param:id", expectValues: []string{"123"}, + expectSource: ExtractorSourcePathParam, }, { name: "ok, query", @@ -70,6 +77,7 @@ func TestCreateExtractors(t *testing.T) { }, whenLoopups: "query:id", expectValues: []string{"999"}, + expectSource: ExtractorSourceQuery, }, { name: "nok, invalid lookup", @@ -100,8 +108,9 @@ func TestCreateExtractors(t *testing.T) { assert.NoError(t, err) for _, e := range extractors { - values, eErr := e(c) + values, source, eErr := e(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, tc.expectSource, source) if tc.expectError != "" { assert.EqualError(t, eErr, tc.expectError) return @@ -226,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) { extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceHeader, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -287,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) { extractor := valuesFromQuery(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceQuery, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -366,8 +377,9 @@ func TestValuesFromParam(t *testing.T) { extractor := valuesFromParam(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourcePathParam, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -446,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) { extractor := valuesFromCookie(tc.whenName) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceCookie, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -483,6 +496,25 @@ func TestValuesFromForm(t *testing.T) { return req } + exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { + var b bytes.Buffer + w := multipart.NewWriter(&b) + w.WriteField("name", "Jon Snow") + w.WriteField("emails[]", "jon@labstack.com") + if mod != nil { + mod(w) + } + + fw, _ := w.CreateFormFile("upload", "my.file") + fw.Write([]byte(`