mirror of
https://github.com/labstack/echo.git
synced 2025-01-07 23:01:56 +02:00
Merge branch 'v5_alpha' into v5_alpha_labstack
# Conflicts: # .github/workflows/echo.yml # Makefile # context.go # context_test.go # echo.go # echo_test.go # group.go # group_test.go # middleware/basic_auth.go # middleware/basic_auth_test.go # middleware/body_limit_test.go # middleware/decompress_test.go # middleware/extractor.go # middleware/jwt.go # middleware/jwt_external_test.go # middleware/jwt_test.go # middleware/key_auth.go # middleware/key_auth_test.go # middleware/logger.go # middleware/method_override_test.go # middleware/recover.go # middleware/recover_test.go # router.go # router_test.go # server_test.go
This commit is contained in:
commit
74022662be
48
.github/workflows/echo.yml
vendored
48
.github/workflows/echo.yml
vendored
@ -19,7 +19,7 @@ on:
|
||||
- '_fixture/**'
|
||||
- '.github/**'
|
||||
- 'codecov.yml'
|
||||
workflow_dispatch: # to be able to run workflow manually
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@ -29,72 +29,72 @@ jobs:
|
||||
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
|
||||
# Echo tests with last four major releases
|
||||
# except v5 starts from 1.17 until there is last four major releases after that
|
||||
go: [1.17]
|
||||
go: [1.17, 1.18]
|
||||
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Set up Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
ref: ${{ github.ref }}
|
||||
|
||||
- name: Set up Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Install Dependencies
|
||||
run: go get -v golang.org/x/lint/golint
|
||||
run: |
|
||||
go install golang.org/x/lint/golint@latest
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
- name: Run Tests
|
||||
run: |
|
||||
golint -set_exit_status ./...
|
||||
staticcheck ./...
|
||||
go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest'
|
||||
uses: codecov/codecov-action@v2
|
||||
if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest'
|
||||
uses: codecov/codecov-action@v1
|
||||
with:
|
||||
token:
|
||||
fail_ci_if_error: false
|
||||
|
||||
benchmark:
|
||||
needs: test
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
go: [1.17]
|
||||
go: [1.18]
|
||||
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Set up Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Checkout Code (Previous)
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
ref: ${{ github.base_ref }}
|
||||
path: previous
|
||||
|
||||
- name: Checkout Code (New)
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
path: new
|
||||
|
||||
- name: Set up Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
|
||||
- name: Install Dependencies
|
||||
run: go get -v golang.org/x/perf/cmd/benchstat
|
||||
run: go install golang.org/x/perf/cmd/benchstat@latest
|
||||
|
||||
- name: Run Benchmark (Previous)
|
||||
run: |
|
||||
cd previous
|
||||
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
|
||||
|
||||
- name: Run Benchmark (New)
|
||||
run: |
|
||||
cd new
|
||||
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
|
||||
|
||||
- name: Run Benchstat
|
||||
run: |
|
||||
benchstat previous/benchmark.txt new/benchmark.txt
|
||||
|
45
CHANGELOG.md
45
CHANGELOG.md
@ -1,5 +1,50 @@
|
||||
# Changelog
|
||||
|
||||
## v4.7.2 - 2022-03-16
|
||||
|
||||
**Fixes**
|
||||
|
||||
* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131)
|
||||
* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136)
|
||||
* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126)
|
||||
|
||||
**Enhancements**
|
||||
|
||||
* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134)
|
||||
|
||||
|
||||
## v4.7.1 - 2022-03-13
|
||||
|
||||
**Fixes**
|
||||
|
||||
* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123)
|
||||
|
||||
**Enhancements**
|
||||
|
||||
* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116)
|
||||
|
||||
|
||||
## v4.7.0 - 2022-03-01
|
||||
|
||||
**Enhancements**
|
||||
|
||||
* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060)
|
||||
* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072)
|
||||
* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027)
|
||||
* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064)
|
||||
|
||||
**Fixes**
|
||||
|
||||
* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007)
|
||||
* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102)
|
||||
|
||||
**General**
|
||||
|
||||
* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103)
|
||||
* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078)
|
||||
* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049)
|
||||
* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README
|
||||
|
||||
## v4.6.3 - 2022-01-10
|
||||
|
||||
**Fixes**
|
||||
|
7
Makefile
7
Makefile
@ -9,7 +9,8 @@ tag:
|
||||
check: lint vet race ## Check project
|
||||
|
||||
init:
|
||||
@go get -u golang.org/x/lint/golint
|
||||
@go install golang.org/x/lint/golint@latest
|
||||
@go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
lint: ## Lint the files
|
||||
@golint -set_exit_status ${PKG_LIST}
|
||||
@ -29,6 +30,6 @@ benchmark: ## Run benchmarks
|
||||
help: ## Display this help screen
|
||||
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
goversion ?= "1.16"
|
||||
test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
|
||||
goversion ?= "1.17"
|
||||
test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17
|
||||
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
|
||||
|
197
binder.go
197
binder.go
@ -1,6 +1,8 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -52,8 +54,11 @@ import (
|
||||
* time
|
||||
* duration
|
||||
* BindUnmarshaler() interface
|
||||
* TextUnmarshaler() interface
|
||||
* JSONUnmarshaler() interface
|
||||
* UnixTime() - converts unix time (integer) to time.Time
|
||||
* UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time
|
||||
* UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time
|
||||
* UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time
|
||||
* CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error`
|
||||
*/
|
||||
|
||||
@ -204,7 +209,7 @@ func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []st
|
||||
return b.customFunc(sourceParam, customFunc, false)
|
||||
}
|
||||
|
||||
// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist.
|
||||
// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist.
|
||||
func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
|
||||
return b.customFunc(sourceParam, customFunc, true)
|
||||
}
|
||||
@ -241,7 +246,7 @@ func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder {
|
||||
return b
|
||||
}
|
||||
|
||||
// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist
|
||||
// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
@ -270,7 +275,7 @@ func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder {
|
||||
return b
|
||||
}
|
||||
|
||||
// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist
|
||||
// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
@ -302,7 +307,7 @@ func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler)
|
||||
return b
|
||||
}
|
||||
|
||||
// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface.
|
||||
// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface.
|
||||
// Returns error when value does not exist
|
||||
func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
@ -321,13 +326,85 @@ func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshal
|
||||
return b
|
||||
}
|
||||
|
||||
// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface
|
||||
func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
}
|
||||
|
||||
tmp := b.ValueFunc(sourceParam)
|
||||
if tmp == "" {
|
||||
return b
|
||||
}
|
||||
|
||||
if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
|
||||
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface.
|
||||
// Returns error when value does not exist
|
||||
func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
}
|
||||
|
||||
tmp := b.ValueFunc(sourceParam)
|
||||
if tmp == "" {
|
||||
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
|
||||
return b
|
||||
}
|
||||
|
||||
if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
|
||||
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface
|
||||
func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
}
|
||||
|
||||
tmp := b.ValueFunc(sourceParam)
|
||||
if tmp == "" {
|
||||
return b
|
||||
}
|
||||
|
||||
if err := dest.UnmarshalText([]byte(tmp)); err != nil {
|
||||
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface.
|
||||
// Returns error when value does not exist
|
||||
func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
}
|
||||
|
||||
tmp := b.ValueFunc(sourceParam)
|
||||
if tmp == "" {
|
||||
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
|
||||
return b
|
||||
}
|
||||
|
||||
if err := dest.UnmarshalText([]byte(tmp)); err != nil {
|
||||
b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// BindWithDelimiter binds parameter to destination by suitable conversion function.
|
||||
// Delimiter is used before conversion to split parameter value to separate values
|
||||
func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder {
|
||||
return b.bindWithDelimiter(sourceParam, dest, delimiter, false)
|
||||
}
|
||||
|
||||
// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function.
|
||||
// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function.
|
||||
// Delimiter is used before conversion to split parameter value to separate values
|
||||
func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder {
|
||||
return b.bindWithDelimiter(sourceParam, dest, delimiter, true)
|
||||
@ -376,7 +453,7 @@ func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 64, false)
|
||||
}
|
||||
|
||||
// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist
|
||||
// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 64, true)
|
||||
}
|
||||
@ -386,7 +463,7 @@ func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 32, false)
|
||||
}
|
||||
|
||||
// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist
|
||||
// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 32, true)
|
||||
}
|
||||
@ -396,7 +473,7 @@ func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 16, false)
|
||||
}
|
||||
|
||||
// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist
|
||||
// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 16, true)
|
||||
}
|
||||
@ -406,7 +483,7 @@ func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 8, false)
|
||||
}
|
||||
|
||||
// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist
|
||||
// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 8, true)
|
||||
}
|
||||
@ -416,7 +493,7 @@ func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 0, false)
|
||||
}
|
||||
|
||||
// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist
|
||||
// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder {
|
||||
return b.intValue(sourceParam, dest, 0, true)
|
||||
}
|
||||
@ -544,7 +621,7 @@ func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist
|
||||
// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -554,7 +631,7 @@ func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist
|
||||
// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -564,7 +641,7 @@ func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist
|
||||
// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -574,7 +651,7 @@ func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist
|
||||
// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -584,7 +661,7 @@ func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist
|
||||
// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder {
|
||||
return b.intsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -594,7 +671,7 @@ func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 64, false)
|
||||
}
|
||||
|
||||
// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist
|
||||
// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 64, true)
|
||||
}
|
||||
@ -604,7 +681,7 @@ func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 32, false)
|
||||
}
|
||||
|
||||
// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist
|
||||
// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 32, true)
|
||||
}
|
||||
@ -614,7 +691,7 @@ func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 16, false)
|
||||
}
|
||||
|
||||
// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist
|
||||
// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 16, true)
|
||||
}
|
||||
@ -624,7 +701,7 @@ func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 8, false)
|
||||
}
|
||||
|
||||
// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist
|
||||
// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 8, true)
|
||||
}
|
||||
@ -634,7 +711,7 @@ func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 8, false)
|
||||
}
|
||||
|
||||
// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist
|
||||
// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 8, true)
|
||||
}
|
||||
@ -644,7 +721,7 @@ func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 0, false)
|
||||
}
|
||||
|
||||
// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist
|
||||
// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder {
|
||||
return b.uintValue(sourceParam, dest, 0, true)
|
||||
}
|
||||
@ -772,7 +849,7 @@ func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist
|
||||
// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -782,7 +859,7 @@ func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist
|
||||
// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -792,7 +869,7 @@ func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist
|
||||
// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -802,7 +879,7 @@ func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist
|
||||
// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -812,7 +889,7 @@ func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist
|
||||
// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder {
|
||||
return b.uintsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -822,7 +899,7 @@ func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder {
|
||||
return b.boolValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist
|
||||
// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder {
|
||||
return b.boolValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -887,7 +964,7 @@ func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder {
|
||||
return b.boolsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist
|
||||
// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist
|
||||
func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder {
|
||||
return b.boolsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -897,7 +974,7 @@ func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder {
|
||||
return b.floatValue(sourceParam, dest, 64, false)
|
||||
}
|
||||
|
||||
// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist
|
||||
// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder {
|
||||
return b.floatValue(sourceParam, dest, 64, true)
|
||||
}
|
||||
@ -907,7 +984,7 @@ func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder {
|
||||
return b.floatValue(sourceParam, dest, 32, false)
|
||||
}
|
||||
|
||||
// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist
|
||||
// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder {
|
||||
return b.floatValue(sourceParam, dest, 32, true)
|
||||
}
|
||||
@ -992,7 +1069,7 @@ func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder
|
||||
return b.floatsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist
|
||||
// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist
|
||||
func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder {
|
||||
return b.floatsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -1002,7 +1079,7 @@ func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder
|
||||
return b.floatsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist
|
||||
// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist
|
||||
func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder {
|
||||
return b.floatsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -1012,7 +1089,7 @@ func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *
|
||||
return b.time(sourceParam, dest, layout, false)
|
||||
}
|
||||
|
||||
// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist
|
||||
// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder {
|
||||
return b.time(sourceParam, dest, layout, true)
|
||||
}
|
||||
@ -1043,7 +1120,7 @@ func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string
|
||||
return b.times(sourceParam, dest, layout, false)
|
||||
}
|
||||
|
||||
// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist
|
||||
// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist
|
||||
func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
|
||||
return b.times(sourceParam, dest, layout, true)
|
||||
}
|
||||
@ -1084,7 +1161,7 @@ func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBi
|
||||
return b.duration(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist
|
||||
// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist
|
||||
func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder {
|
||||
return b.duration(sourceParam, dest, true)
|
||||
}
|
||||
@ -1115,7 +1192,7 @@ func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *Valu
|
||||
return b.durationsValue(sourceParam, dest, false)
|
||||
}
|
||||
|
||||
// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist
|
||||
// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist
|
||||
func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder {
|
||||
return b.durationsValue(sourceParam, dest, true)
|
||||
}
|
||||
@ -1161,10 +1238,10 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, false, false)
|
||||
return b.unixTime(sourceParam, dest, false, time.Second)
|
||||
}
|
||||
|
||||
// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding
|
||||
// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding
|
||||
// to the given Unix time). Returns error when value does not exist.
|
||||
//
|
||||
// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
|
||||
@ -1172,10 +1249,31 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, true, false)
|
||||
return b.unixTime(sourceParam, dest, true, time.Second)
|
||||
}
|
||||
|
||||
// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision).
|
||||
// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision).
|
||||
//
|
||||
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, false, time.Millisecond)
|
||||
}
|
||||
|
||||
// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding
|
||||
// to the given Unix time in millisecond precision). Returns error when value does not exist.
|
||||
//
|
||||
// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
|
||||
//
|
||||
// Note:
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, true, time.Millisecond)
|
||||
}
|
||||
|
||||
// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision).
|
||||
//
|
||||
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
|
||||
// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
|
||||
@ -1185,10 +1283,10 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, false, true)
|
||||
return b.unixTime(sourceParam, dest, false, time.Nanosecond)
|
||||
}
|
||||
|
||||
// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding
|
||||
// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding
|
||||
// to the given Unix time value in nano second precision). Returns error when value does not exist.
|
||||
//
|
||||
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
|
||||
@ -1199,10 +1297,10 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi
|
||||
// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
|
||||
// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
|
||||
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
|
||||
return b.unixTime(sourceParam, dest, true, true)
|
||||
return b.unixTime(sourceParam, dest, true, time.Nanosecond)
|
||||
}
|
||||
|
||||
func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder {
|
||||
func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder {
|
||||
if b.failFast && b.errors != nil {
|
||||
return b
|
||||
}
|
||||
@ -1221,10 +1319,13 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi
|
||||
return b
|
||||
}
|
||||
|
||||
if isNano {
|
||||
*dest = time.Unix(0, n)
|
||||
} else {
|
||||
switch precision {
|
||||
case time.Second:
|
||||
*dest = time.Unix(n, 0)
|
||||
case time.Millisecond:
|
||||
*dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows
|
||||
case time.Nanosecond:
|
||||
*dest = time.Unix(0, n)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
276
binder_test.go
276
binder_test.go
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
@ -55,7 +56,7 @@ func TestBindingError_Error(t *testing.T) {
|
||||
func TestBindingError_ErrorJSON(t *testing.T) {
|
||||
err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
|
||||
|
||||
resp, err := json.Marshal(err)
|
||||
resp, _ := json.Marshal(err)
|
||||
|
||||
assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp))
|
||||
}
|
||||
@ -2188,6 +2189,188 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueBinder_JSONUnmarshaler(t *testing.T) {
|
||||
example := big.NewInt(999)
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenFailFast bool
|
||||
givenBindErrors []error
|
||||
whenURL string
|
||||
whenMust bool
|
||||
expectValue big.Int
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, binds value",
|
||||
whenURL: "/search?param=999¶m=998",
|
||||
expectValue: *example,
|
||||
},
|
||||
{
|
||||
name: "ok, params values empty, value is not changed",
|
||||
whenURL: "/search?nope=1",
|
||||
expectValue: big.Int{},
|
||||
},
|
||||
{
|
||||
name: "nok, previous errors fail fast without binding value",
|
||||
givenFailFast: true,
|
||||
whenURL: "/search?param=1¶m=100",
|
||||
expectValue: big.Int{},
|
||||
expectError: "previous error",
|
||||
},
|
||||
{
|
||||
name: "nok, conversion fails, value is not changed",
|
||||
whenURL: "/search?param=nope¶m=xxx",
|
||||
expectValue: big.Int{},
|
||||
expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
|
||||
},
|
||||
{
|
||||
name: "ok (must), binds value",
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=999¶m=998",
|
||||
expectValue: *example,
|
||||
},
|
||||
{
|
||||
name: "ok (must), params values empty, returns error, value is not changed",
|
||||
whenMust: true,
|
||||
whenURL: "/search?nope=1",
|
||||
expectValue: big.Int{},
|
||||
expectError: "code=400, message=required field value is empty, field=param",
|
||||
},
|
||||
{
|
||||
name: "nok (must), previous errors fail fast without binding value",
|
||||
givenFailFast: true,
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=1¶m=xxx",
|
||||
expectValue: big.Int{},
|
||||
expectError: "previous error",
|
||||
},
|
||||
{
|
||||
name: "nok (must), conversion fails, value is not changed",
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=nope¶m=xxx",
|
||||
expectValue: big.Int{},
|
||||
expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := createTestContext(tc.whenURL, nil, nil)
|
||||
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
|
||||
if tc.givenFailFast {
|
||||
b.errors = []error{errors.New("previous error")}
|
||||
}
|
||||
|
||||
var dest big.Int
|
||||
var err error
|
||||
if tc.whenMust {
|
||||
err = b.MustJSONUnmarshaler("param", &dest).BindError()
|
||||
} else {
|
||||
err = b.JSONUnmarshaler("param", &dest).BindError()
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectValue, dest)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueBinder_TextUnmarshaler(t *testing.T) {
|
||||
example := big.NewInt(999)
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenFailFast bool
|
||||
givenBindErrors []error
|
||||
whenURL string
|
||||
whenMust bool
|
||||
expectValue big.Int
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, binds value",
|
||||
whenURL: "/search?param=999¶m=998",
|
||||
expectValue: *example,
|
||||
},
|
||||
{
|
||||
name: "ok, params values empty, value is not changed",
|
||||
whenURL: "/search?nope=1",
|
||||
expectValue: big.Int{},
|
||||
},
|
||||
{
|
||||
name: "nok, previous errors fail fast without binding value",
|
||||
givenFailFast: true,
|
||||
whenURL: "/search?param=1¶m=100",
|
||||
expectValue: big.Int{},
|
||||
expectError: "previous error",
|
||||
},
|
||||
{
|
||||
name: "nok, conversion fails, value is not changed",
|
||||
whenURL: "/search?param=nope¶m=xxx",
|
||||
expectValue: big.Int{},
|
||||
expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
|
||||
},
|
||||
{
|
||||
name: "ok (must), binds value",
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=999¶m=998",
|
||||
expectValue: *example,
|
||||
},
|
||||
{
|
||||
name: "ok (must), params values empty, returns error, value is not changed",
|
||||
whenMust: true,
|
||||
whenURL: "/search?nope=1",
|
||||
expectValue: big.Int{},
|
||||
expectError: "code=400, message=required field value is empty, field=param",
|
||||
},
|
||||
{
|
||||
name: "nok (must), previous errors fail fast without binding value",
|
||||
givenFailFast: true,
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=1¶m=xxx",
|
||||
expectValue: big.Int{},
|
||||
expectError: "previous error",
|
||||
},
|
||||
{
|
||||
name: "nok (must), conversion fails, value is not changed",
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=nope¶m=xxx",
|
||||
expectValue: big.Int{},
|
||||
expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := createTestContext(tc.whenURL, nil, nil)
|
||||
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
|
||||
if tc.givenFailFast {
|
||||
b.errors = []error{errors.New("previous error")}
|
||||
}
|
||||
|
||||
var dest big.Int
|
||||
var err error
|
||||
if tc.whenMust {
|
||||
err = b.MustTextUnmarshaler("param", &dest).BindError()
|
||||
} else {
|
||||
err = b.TextUnmarshaler("param", &dest).BindError()
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectValue, dest)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
@ -2530,6 +2713,97 @@ func TestValueBinder_UnixTime(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueBinder_UnixTimeMilli(t *testing.T) {
|
||||
exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenFailFast bool
|
||||
givenBindErrors []error
|
||||
whenURL string
|
||||
whenMust bool
|
||||
expectValue time.Time
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, binds value, unix time in milliseconds",
|
||||
whenURL: "/search?param=1647184410140¶m=1647184410199",
|
||||
expectValue: exampleTime,
|
||||
},
|
||||
{
|
||||
name: "ok, params values empty, value is not changed",
|
||||
whenURL: "/search?nope=1",
|
||||
expectValue: time.Time{},
|
||||
},
|
||||
{
|
||||
name: "nok, previous errors fail fast without binding value",
|
||||
givenFailFast: true,
|
||||
whenURL: "/search?param=1¶m=100",
|
||||
expectValue: time.Time{},
|
||||
expectError: "previous error",
|
||||
},
|
||||
{
|
||||
name: "nok, conversion fails, value is not changed",
|
||||
whenURL: "/search?param=nope¶m=100",
|
||||
expectValue: time.Time{},
|
||||
expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
|
||||
},
|
||||
{
|
||||
name: "ok (must), binds value",
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=1647184410140¶m=1647184410199",
|
||||
expectValue: exampleTime,
|
||||
},
|
||||
{
|
||||
name: "ok (must), params values empty, returns error, value is not changed",
|
||||
whenMust: true,
|
||||
whenURL: "/search?nope=1",
|
||||
expectValue: time.Time{},
|
||||
expectError: "code=400, message=required field value is empty, field=param",
|
||||
},
|
||||
{
|
||||
name: "nok (must), previous errors fail fast without binding value",
|
||||
givenFailFast: true,
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=1¶m=100",
|
||||
expectValue: time.Time{},
|
||||
expectError: "previous error",
|
||||
},
|
||||
{
|
||||
name: "nok (must), conversion fails, value is not changed",
|
||||
whenMust: true,
|
||||
whenURL: "/search?param=nope¶m=100",
|
||||
expectValue: time.Time{},
|
||||
expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := createTestContext(tc.whenURL, nil, nil)
|
||||
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
|
||||
if tc.givenFailFast {
|
||||
b.errors = []error{errors.New("previous error")}
|
||||
}
|
||||
|
||||
dest := time.Time{}
|
||||
var err error
|
||||
if tc.whenMust {
|
||||
err = b.MustUnixTimeMilli("param", &dest).BindError()
|
||||
} else {
|
||||
err = b.UnixTimeMilli("param", &dest).BindError()
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
|
||||
assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueBinder_UnixTimeNano(t *testing.T) {
|
||||
exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603
|
||||
exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789
|
||||
|
21
context.go
21
context.go
@ -55,6 +55,16 @@ type Context interface {
|
||||
// PathParam returns path parameter by name.
|
||||
PathParam(name string) string
|
||||
|
||||
// PathParamDefault returns the path parameter or default value for the provided name.
|
||||
//
|
||||
// Notes for DefaultRouter implementation:
|
||||
// Path parameter could be empty for cases like that:
|
||||
// * route `/release-:version/bin` and request URL is `/release-/bin`
|
||||
// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg`
|
||||
// but not when path parameter is last part of route path
|
||||
// * route `/download/file.:ext` will not match request `/download/file.`
|
||||
PathParamDefault(name string, defaultValue string) string
|
||||
|
||||
// PathParams returns path parameter values.
|
||||
PathParams() PathParams
|
||||
|
||||
@ -176,6 +186,9 @@ type Context interface {
|
||||
Redirect(code int, url string) error
|
||||
|
||||
// Echo returns the `Echo` instance.
|
||||
//
|
||||
// WARNING: Remember that Echo public fields and methods are coroutine safe ONLY when you are NOT mutating them
|
||||
// anywhere in your code after Echo server has started.
|
||||
Echo() *Echo
|
||||
}
|
||||
|
||||
@ -379,7 +392,10 @@ func (c *DefaultContext) PathParam(name string) string {
|
||||
return c.pathParams.Get(name, "")
|
||||
}
|
||||
|
||||
// PathParamDefault does not exist as expecting empty path param makes no sense
|
||||
// PathParamDefault returns the path parameter or default value for the provided name.
|
||||
func (c *DefaultContext) PathParamDefault(name, defaultValue string) string {
|
||||
return c.pathParams.Get(name, defaultValue)
|
||||
}
|
||||
|
||||
// PathParams returns path parameter values.
|
||||
func (c *DefaultContext) PathParams() PathParams {
|
||||
@ -406,7 +422,8 @@ func (c *DefaultContext) QueryParam(name string) string {
|
||||
}
|
||||
|
||||
// QueryParamDefault returns the query param or default value for the provided name.
|
||||
// Note: QueryParamDefault does not distinguish if form had no value by that name or value was empty string
|
||||
// Note: QueryParamDefault does not distinguish if query had no value by that name or value was empty string
|
||||
// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamDefault("search", "1")`
|
||||
func (c *DefaultContext) QueryParamDefault(name, defaultValue string) string {
|
||||
value := c.QueryParam(name)
|
||||
if value == "" {
|
||||
|
286
context_test.go
286
context_test.go
@ -448,22 +448,142 @@ func TestContextCookie(t *testing.T) {
|
||||
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
|
||||
}
|
||||
|
||||
func TestContextPathParam(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c := e.NewContext(req, nil).(*DefaultContext)
|
||||
|
||||
params := &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "fid", Value: "501"},
|
||||
func TestContext_PathParams(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given *PathParams
|
||||
expect PathParams
|
||||
}{
|
||||
{
|
||||
name: "param exists",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "fid", Value: "501"},
|
||||
},
|
||||
expect: PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "fid", Value: "501"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "params is empty",
|
||||
given: &PathParams{},
|
||||
expect: PathParams{},
|
||||
},
|
||||
}
|
||||
// ParamNames
|
||||
c.pathParams = params
|
||||
assert.EqualValues(t, *params, c.PathParams())
|
||||
|
||||
// Param
|
||||
assert.Equal(t, "501", c.PathParam("fid"))
|
||||
assert.Equal(t, "", c.PathParam("undefined"))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
c.(RoutableContext).SetRawPathParams(tc.given)
|
||||
|
||||
assert.EqualValues(t, tc.expect, c.PathParams())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_PathParam(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given *PathParams
|
||||
whenParamName string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "param exists",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "fid", Value: "501"},
|
||||
},
|
||||
whenParamName: "uid",
|
||||
expect: "101",
|
||||
},
|
||||
{
|
||||
name: "multiple same param values exists - return first",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "uid", Value: "202"},
|
||||
{Name: "fid", Value: "501"},
|
||||
},
|
||||
whenParamName: "uid",
|
||||
expect: "101",
|
||||
},
|
||||
{
|
||||
name: "param does not exists",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
},
|
||||
whenParamName: "nope",
|
||||
expect: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
c.(RoutableContext).SetRawPathParams(tc.given)
|
||||
|
||||
assert.EqualValues(t, tc.expect, c.PathParam(tc.whenParamName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_PathParamDefault(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given *PathParams
|
||||
whenParamName string
|
||||
whenDefaultValue string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "param exists",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
{Name: "fid", Value: "501"},
|
||||
},
|
||||
whenParamName: "uid",
|
||||
whenDefaultValue: "999",
|
||||
expect: "101",
|
||||
},
|
||||
{
|
||||
name: "param exists and is empty",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: ""},
|
||||
{Name: "fid", Value: "501"},
|
||||
},
|
||||
whenParamName: "uid",
|
||||
whenDefaultValue: "999",
|
||||
expect: "", // <-- this is different from QueryParamDefault behaviour
|
||||
},
|
||||
{
|
||||
name: "param does not exists",
|
||||
given: &PathParams{
|
||||
{Name: "uid", Value: "101"},
|
||||
},
|
||||
whenParamName: "nope",
|
||||
whenDefaultValue: "999",
|
||||
expect: "999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
c.(RoutableContext).SetRawPathParams(tc.given)
|
||||
|
||||
assert.EqualValues(t, tc.expect, c.PathParamDefault(tc.whenParamName, tc.whenDefaultValue))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextGetAndSetParam(t *testing.T) {
|
||||
@ -568,27 +688,129 @@ func TestContextFormValue(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContextQueryParam(t *testing.T) {
|
||||
q := make(url.Values)
|
||||
q.Set("name", "Jon Snow")
|
||||
q.Set("email", "jon@labstack.com")
|
||||
req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
|
||||
e := New()
|
||||
c := e.NewContext(req, nil)
|
||||
func TestContext_QueryParams(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenURL string
|
||||
expect url.Values
|
||||
}{
|
||||
{
|
||||
name: "multiple values in url",
|
||||
givenURL: "/?test=1&test=2&email=jon%40labstack.com",
|
||||
expect: url.Values{
|
||||
"test": []string{"1", "2"},
|
||||
"email": []string{"jon@labstack.com"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single value in url",
|
||||
givenURL: "/?nope=1",
|
||||
expect: url.Values{
|
||||
"nope": []string{"1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no query params in url",
|
||||
givenURL: "/?",
|
||||
expect: url.Values{},
|
||||
},
|
||||
}
|
||||
|
||||
// QueryParam
|
||||
assert.Equal(t, "Jon Snow", c.QueryParam("name"))
|
||||
assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
|
||||
e := New()
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
// QueryParamDefault
|
||||
assert.Equal(t, "Jon Snow", c.QueryParamDefault("name", "nope"))
|
||||
assert.Equal(t, "default", c.QueryParamDefault("missing", "default"))
|
||||
assert.Equal(t, tc.expect, c.QueryParams())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// QueryParams
|
||||
assert.Equal(t, url.Values{
|
||||
"name": []string{"Jon Snow"},
|
||||
"email": []string{"jon@labstack.com"},
|
||||
}, c.QueryParams())
|
||||
func TestContext_QueryParam(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenURL string
|
||||
whenParamName string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "value exists in url",
|
||||
givenURL: "/?test=1",
|
||||
whenParamName: "test",
|
||||
expect: "1",
|
||||
},
|
||||
{
|
||||
name: "multiple values exists in url",
|
||||
givenURL: "/?test=9&test=8",
|
||||
whenParamName: "test",
|
||||
expect: "9", // <-- first value in returned
|
||||
},
|
||||
{
|
||||
name: "value does not exists in url",
|
||||
givenURL: "/?nope=1",
|
||||
whenParamName: "test",
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
name: "value is empty in url",
|
||||
givenURL: "/?test=",
|
||||
whenParamName: "test",
|
||||
expect: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
|
||||
e := New()
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_QueryParamDefault(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenURL string
|
||||
whenParamName string
|
||||
whenDefaultValue string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "value exists in url",
|
||||
givenURL: "/?test=1",
|
||||
whenParamName: "test",
|
||||
whenDefaultValue: "999",
|
||||
expect: "1",
|
||||
},
|
||||
{
|
||||
name: "value does not exists in url",
|
||||
givenURL: "/?nope=1",
|
||||
whenParamName: "test",
|
||||
whenDefaultValue: "999",
|
||||
expect: "999",
|
||||
},
|
||||
{
|
||||
name: "value is empty in url",
|
||||
givenURL: "/?test=",
|
||||
whenParamName: "test",
|
||||
whenDefaultValue: "999",
|
||||
expect: "999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
|
||||
e := New()
|
||||
c := e.NewContext(req, nil)
|
||||
|
||||
assert.Equal(t, tc.expect, c.QueryParamDefault(tc.whenParamName, tc.whenDefaultValue))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextFormFile(t *testing.T) {
|
||||
|
58
echo.go
58
echo.go
@ -49,12 +49,15 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Echo is the top-level framework instance.
|
||||
// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics.
|
||||
//
|
||||
// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. This is very likely
|
||||
// to happen when you access Echo instances through Context.Echo() method.
|
||||
type Echo struct {
|
||||
// premiddleware are middlewares that are run for every request before routing is done
|
||||
premiddleware []MiddlewareFunc
|
||||
@ -66,8 +69,8 @@ type Echo struct {
|
||||
routerCreator func(e *Echo) Router
|
||||
|
||||
contextPool sync.Pool
|
||||
// contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info for context
|
||||
// creation time so we can allocate path parameter values slice.
|
||||
// contextPathParamAllocSize holds maximum parameter count for all added routes. This is necessary info at context
|
||||
// creation moment so we can allocate path parameter values slice with correct size.
|
||||
contextPathParamAllocSize int
|
||||
|
||||
// NewContextFunc allows using custom context implementations, instead of default *echo.context
|
||||
@ -150,6 +153,8 @@ const (
|
||||
PROPFIND = "PROPFIND"
|
||||
// REPORT Method can be used to get information about a resource, see rfc 3253
|
||||
REPORT = "REPORT"
|
||||
// RouteNotFound is special method type for routes handling "route not found" (404) cases
|
||||
RouteNotFound = "echo_route_not_found"
|
||||
)
|
||||
|
||||
// Headers
|
||||
@ -181,12 +186,14 @@ const (
|
||||
HeaderXForwardedSsl = "X-Forwarded-Ssl"
|
||||
HeaderXUrlScheme = "X-Url-Scheme"
|
||||
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
HeaderXRequestID = "X-Request-ID"
|
||||
HeaderXCorrelationID = "X-Correlation-ID"
|
||||
HeaderXRealIP = "X-Real-Ip"
|
||||
HeaderXRequestID = "X-Request-Id"
|
||||
HeaderXCorrelationID = "X-Correlation-Id"
|
||||
HeaderXRequestedWith = "X-Requested-With"
|
||||
HeaderServer = "Server"
|
||||
HeaderOrigin = "Origin"
|
||||
HeaderCacheControl = "Cache-Control"
|
||||
HeaderConnection = "Connection"
|
||||
|
||||
// Access control
|
||||
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
|
||||
@ -403,6 +410,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
|
||||
return e.Add(http.MethodTrace, path, h, m...)
|
||||
}
|
||||
|
||||
// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases)
|
||||
// for current request URL.
|
||||
// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
|
||||
// wildcard/match-any character (`/*`, `/download/*` etc).
|
||||
//
|
||||
// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
|
||||
func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return e.Add(RouteNotFound, path, h, m...)
|
||||
}
|
||||
|
||||
// Any registers a new route for all supported HTTP methods and path with matching handler
|
||||
// in the router with optional route-level middleware. Panics on error.
|
||||
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
|
||||
@ -707,20 +724,26 @@ func newDefaultFS() *defaultFS {
|
||||
dir, _ := os.Getwd()
|
||||
return &defaultFS{
|
||||
prefix: dir,
|
||||
fs: os.DirFS(dir),
|
||||
fs: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (fs defaultFS) Open(name string) (fs.File, error) {
|
||||
if fs.fs == nil {
|
||||
return os.Open(name)
|
||||
}
|
||||
return fs.fs.Open(name)
|
||||
}
|
||||
|
||||
func subFS(currentFs fs.FS, root string) (fs.FS, error) {
|
||||
root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
|
||||
if dFS, ok := currentFs.(*defaultFS); ok {
|
||||
// we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to
|
||||
// allow cases when root is given as `../somepath` which is not valid for fs.FS
|
||||
root = filepath.Join(dFS.prefix, root)
|
||||
// we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
|
||||
// fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
|
||||
// were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
|
||||
if isRelativePath(root) {
|
||||
root = filepath.Join(dFS.prefix, root)
|
||||
}
|
||||
return &defaultFS{
|
||||
prefix: root,
|
||||
fs: os.DirFS(root),
|
||||
@ -729,6 +752,21 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) {
|
||||
return fs.Sub(currentFs, root)
|
||||
}
|
||||
|
||||
func isRelativePath(path string) bool {
|
||||
if path == "" {
|
||||
return true
|
||||
}
|
||||
if path[0] == '/' {
|
||||
return false
|
||||
}
|
||||
if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 {
|
||||
// https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names
|
||||
// https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// MustSubFS creates sub FS from current filesystem or panic on failure.
|
||||
// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
|
||||
//
|
||||
|
120
echo_test.go
120
echo_test.go
@ -336,16 +336,54 @@ func TestEchoStaticRedirectIndex(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEchoFile(t *testing.T) {
|
||||
e := New()
|
||||
ri := e.File("/walle", "_fixture/images/walle.png")
|
||||
assert.Equal(t, http.MethodGet, ri.Method())
|
||||
assert.Equal(t, "/walle", ri.Path())
|
||||
assert.Equal(t, "GET:/walle", ri.Name())
|
||||
assert.Nil(t, ri.Params())
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenPath string
|
||||
givenFile string
|
||||
whenPath string
|
||||
expectCode int
|
||||
expectStartsWith string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
givenPath: "/walle",
|
||||
givenFile: "_fixture/images/walle.png",
|
||||
whenPath: "/walle",
|
||||
expectCode: http.StatusOK,
|
||||
expectStartsWith: string([]byte{0x89, 0x50, 0x4e}),
|
||||
},
|
||||
{
|
||||
name: "ok with relative path",
|
||||
givenPath: "/",
|
||||
givenFile: "./go.mod",
|
||||
whenPath: "/",
|
||||
expectCode: http.StatusOK,
|
||||
expectStartsWith: "module github.com/labstack/echo/v",
|
||||
},
|
||||
{
|
||||
name: "nok file does not exist",
|
||||
givenPath: "/",
|
||||
givenFile: "./this-file-does-not-exist",
|
||||
whenPath: "/",
|
||||
expectCode: http.StatusNotFound,
|
||||
expectStartsWith: "{\"message\":\"Not Found\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
c, b := request(http.MethodGet, "/walle", e)
|
||||
assert.Equal(t, http.StatusOK, c)
|
||||
assert.NotEmpty(t, b)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New() // we are using echo.defaultFS instance
|
||||
e.File(tc.givenPath, tc.givenFile)
|
||||
|
||||
c, b := request(http.MethodGet, tc.whenPath, e)
|
||||
assert.Equal(t, tc.expectCode, c)
|
||||
|
||||
if len(b) > len(tc.expectStartsWith) {
|
||||
b = b[:len(tc.expectStartsWith)]
|
||||
}
|
||||
assert.Equal(t, tc.expectStartsWith, b)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoMiddleware(t *testing.T) {
|
||||
@ -880,6 +918,70 @@ func TestEchoGroup(t *testing.T) {
|
||||
assert.Equal(t, "023", buf.String())
|
||||
}
|
||||
|
||||
func TestEcho_RouteNotFound(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectCode int
|
||||
}{
|
||||
{
|
||||
name: "404, route to static not found handler /a/c/xx",
|
||||
whenURL: "/a/c/xx",
|
||||
expectRoute: "GET /a/c/xx",
|
||||
expectCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "404, route to path param not found handler /a/:file",
|
||||
whenURL: "/a/echo.exe",
|
||||
expectRoute: "GET /a/:file",
|
||||
expectCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "404, route to any not found handler /*",
|
||||
whenURL: "/b/echo.exe",
|
||||
expectRoute: "GET /*",
|
||||
expectCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "200, route /a/c/df to /a/c/df",
|
||||
whenURL: "/a/c/df",
|
||||
expectRoute: "GET /a/c/df",
|
||||
expectCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
okHandler := func(c Context) error {
|
||||
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
|
||||
}
|
||||
notFoundHandler := func(c Context) error {
|
||||
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
|
||||
}
|
||||
|
||||
e.GET("/", okHandler)
|
||||
e.GET("/a/c/df", okHandler)
|
||||
e.GET("/a/b*", okHandler)
|
||||
e.PUT("/*", okHandler)
|
||||
|
||||
e.RouteNotFound("/a/c/xx", notFoundHandler) // static
|
||||
e.RouteNotFound("/a/:file", notFoundHandler) // param
|
||||
e.RouteNotFound("/*", notFoundHandler) // any
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectCode, rec.Code)
|
||||
assert.Equal(t, tc.expectRoute, rec.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEchoNotFound(t *testing.T) {
|
||||
e := New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/files", nil)
|
||||
|
7
group.go
7
group.go
@ -157,6 +157,13 @@ func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
|
||||
return g.Add(http.MethodGet, path, handler, middleware...)
|
||||
}
|
||||
|
||||
// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
|
||||
//
|
||||
// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })`
|
||||
func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
|
||||
return g.Add(RouteNotFound, path, h, m...)
|
||||
}
|
||||
|
||||
// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
|
||||
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
|
||||
ri, err := g.AddRoute(Route{
|
||||
|
@ -303,6 +303,71 @@ func TestGroup_TRACE(t *testing.T) {
|
||||
assert.Equal(t, `OK`, body)
|
||||
}
|
||||
|
||||
func TestGroup_RouteNotFound(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectCode int
|
||||
}{
|
||||
{
|
||||
name: "404, route to static not found handler /group/a/c/xx",
|
||||
whenURL: "/group/a/c/xx",
|
||||
expectRoute: "GET /group/a/c/xx",
|
||||
expectCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "404, route to path param not found handler /group/a/:file",
|
||||
whenURL: "/group/a/echo.exe",
|
||||
expectRoute: "GET /group/a/:file",
|
||||
expectCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "404, route to any not found handler /group/*",
|
||||
whenURL: "/group/b/echo.exe",
|
||||
expectRoute: "GET /group/*",
|
||||
expectCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "200, route /group/a/c/df to /group/a/c/df",
|
||||
whenURL: "/group/a/c/df",
|
||||
expectRoute: "GET /group/a/c/df",
|
||||
expectCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
g := e.Group("/group")
|
||||
|
||||
okHandler := func(c Context) error {
|
||||
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
|
||||
}
|
||||
notFoundHandler := func(c Context) error {
|
||||
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
|
||||
}
|
||||
|
||||
g.GET("/", okHandler)
|
||||
g.GET("/a/c/df", okHandler)
|
||||
g.GET("/a/b*", okHandler)
|
||||
g.PUT("/*", okHandler)
|
||||
|
||||
g.RouteNotFound("/a/c/xx", notFoundHandler) // static
|
||||
g.RouteNotFound("/a/:file", notFoundHandler) // param
|
||||
g.RouteNotFound("/*", notFoundHandler) // any
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, tc.expectCode, rec.Code)
|
||||
assert.Equal(t, tc.expectRoute, rec.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_Any(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
|
142
ip.go
142
ip.go
@ -6,6 +6,130 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
/**
|
||||
By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 )
|
||||
Source: https://echo.labstack.com/guide/ip-address/
|
||||
|
||||
IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more.
|
||||
Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that.
|
||||
|
||||
However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application.
|
||||
In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally.
|
||||
Otherwise, you might give someone a chance of deceiving you. **A security risk!**
|
||||
|
||||
To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure.
|
||||
In Echo, this can be done by configuring `Echo#IPExtractor` appropriately.
|
||||
This guides show you why and how.
|
||||
|
||||
> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice.
|
||||
|
||||
Let's start from two questions to know the right direction:
|
||||
|
||||
1. Do you put any HTTP (L7) proxy in front of the application?
|
||||
- It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway).
|
||||
2. If yes, what HTTP header do your proxies use to pass client IP to the application?
|
||||
|
||||
## Case 1. With no proxy
|
||||
|
||||
If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer.
|
||||
Any HTTP header is untrustable because the clients have full control what headers to be set.
|
||||
|
||||
In this case, use `echo.ExtractIPDirect()`.
|
||||
|
||||
```go
|
||||
e.IPExtractor = echo.ExtractIPDirect()
|
||||
```
|
||||
|
||||
## Case 2. With proxies using `X-Forwarded-For` header
|
||||
|
||||
[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header
|
||||
to relay clients' IP addresses.
|
||||
At each hop on the proxies, they append the request IP address at the end of the header.
|
||||
|
||||
Following example diagram illustrates this behavior.
|
||||
|
||||
```text
|
||||
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||
│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │
|
||||
│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │
|
||||
└──────────┘ └──────────┘ └──────────┘ └──────────┘
|
||||
|
||||
Case 1.
|
||||
XFF: "" "a" "a, b"
|
||||
~~~~~~
|
||||
Case 2.
|
||||
XFF: "x" "x, a" "x, a, b"
|
||||
~~~~~~~~~
|
||||
↑ What your app will see
|
||||
```
|
||||
|
||||
In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is
|
||||
configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre".
|
||||
In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`.
|
||||
|
||||
In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`.
|
||||
|
||||
```go
|
||||
e.IPExtractor = echo.ExtractIPFromXFFHeader()
|
||||
```
|
||||
|
||||
By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address
|
||||
from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
|
||||
[RFC4193](https://tools.ietf.org/html/rfc4193)).
|
||||
To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
|
||||
|
||||
E.g.:
|
||||
|
||||
```go
|
||||
e.IPExtractor = echo.ExtractIPFromXFFHeader(
|
||||
TrustLinkLocal(false),
|
||||
TrustIPRanges(lbIPRange),
|
||||
)
|
||||
```
|
||||
|
||||
- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
|
||||
|
||||
## Case 3. With proxies using `X-Real-IP` header
|
||||
|
||||
`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF.
|
||||
|
||||
If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`.
|
||||
|
||||
```go
|
||||
e.IPExtractor = echo.ExtractIPFromRealIPHeader()
|
||||
```
|
||||
|
||||
Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address
|
||||
from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
|
||||
[RFC4193](https://tools.ietf.org/html/rfc4193)).
|
||||
To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
|
||||
|
||||
- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
|
||||
|
||||
> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**.
|
||||
> Otherwise there is a chance of fraud, as it is what clients can control.
|
||||
|
||||
## About default behavior
|
||||
|
||||
In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer.
|
||||
|
||||
As you might already notice, after reading this article, this is not good.
|
||||
Sole reason this is default is just backward compatibility.
|
||||
|
||||
## Private IP ranges
|
||||
|
||||
See: https://en.wikipedia.org/wiki/Private_network
|
||||
|
||||
Private IPv4 address ranges (RFC 1918):
|
||||
* 10.0.0.0 – 10.255.255.255 (24-bit block)
|
||||
* 172.16.0.0 – 172.31.255.255 (20-bit block)
|
||||
* 192.168.0.0 – 192.168.255.255 (16-bit block)
|
||||
|
||||
Private IPv6 address ranges:
|
||||
* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
|
||||
|
||||
*/
|
||||
|
||||
type ipChecker struct {
|
||||
trustLoopback bool
|
||||
trustLinkLocal bool
|
||||
@ -52,6 +176,7 @@ func newIPChecker(configs []TrustOption) *ipChecker {
|
||||
return checker
|
||||
}
|
||||
|
||||
// Go1.16+ added `ip.IsPrivate()` but until that use this implementation
|
||||
func isPrivateIPRange(ip net.IP) bool {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4[0] == 10 ||
|
||||
@ -87,10 +212,12 @@ type IPExtractor func(*http.Request) string
|
||||
// ExtractIPDirect extracts IP address using actual IP address.
|
||||
// Use this if your server faces to internet directory (i.e.: uses no proxy).
|
||||
func ExtractIPDirect() IPExtractor {
|
||||
return func(req *http.Request) string {
|
||||
ra, _, _ := net.SplitHostPort(req.RemoteAddr)
|
||||
return ra
|
||||
}
|
||||
return extractIP
|
||||
}
|
||||
|
||||
func extractIP(req *http.Request) string {
|
||||
ra, _, _ := net.SplitHostPort(req.RemoteAddr)
|
||||
return ra
|
||||
}
|
||||
|
||||
// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header.
|
||||
@ -98,14 +225,13 @@ func ExtractIPDirect() IPExtractor {
|
||||
func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
|
||||
checker := newIPChecker(options)
|
||||
return func(req *http.Request) string {
|
||||
directIP := ExtractIPDirect()(req)
|
||||
realIP := req.Header.Get(HeaderXRealIP)
|
||||
if realIP != "" {
|
||||
if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) {
|
||||
if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) {
|
||||
return realIP
|
||||
}
|
||||
}
|
||||
return directIP
|
||||
return extractIP(req)
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,7 +241,7 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
|
||||
func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
|
||||
checker := newIPChecker(options)
|
||||
return func(req *http.Request) string {
|
||||
directIP := ExtractIPDirect()(req)
|
||||
directIP := extractIP(req)
|
||||
xffs := req.Header[HeaderXForwardedFor]
|
||||
if len(xffs) == 0 {
|
||||
return directIP
|
||||
|
803
ip_test.go
803
ip_test.go
@ -1,235 +1,606 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
testify "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
// For RemoteAddr
|
||||
ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8
|
||||
sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080"
|
||||
ipForRemoteAddrExternal = "203.0.113.1"
|
||||
sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080"
|
||||
// For x-real-ip
|
||||
ipForRealIP = "203.0.113.10"
|
||||
// For XFF
|
||||
ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16
|
||||
ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16
|
||||
ipForXFF3External = "2001:db8::103"
|
||||
ipForXFF4Private = "fc00::104" // From fc00::/7
|
||||
ipForXFF5External = "198.51.100.105"
|
||||
ipForXFF6External = "192.0.2.106"
|
||||
ipForXFFBroken = "this.is.broken.lol"
|
||||
// keys for test cases
|
||||
ipTestReqKeyNoHeader = "no header"
|
||||
ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external"
|
||||
ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal"
|
||||
ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external"
|
||||
ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal"
|
||||
ipTestReqKeyXFFExternal = "xff; remote addr external"
|
||||
ipTestReqKeyXFFInternal = "xff; remote addr internal"
|
||||
ipTestReqKeyBrokenXFF = "broken xff"
|
||||
)
|
||||
|
||||
var (
|
||||
sampleXFF = strings.Join([]string{
|
||||
ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal,
|
||||
}, ", ")
|
||||
|
||||
requests = map[string]*http.Request{
|
||||
ipTestReqKeyNoHeader: &http.Request{
|
||||
RemoteAddr: sampleRemoteAddrExternal,
|
||||
},
|
||||
ipTestReqKeyRealIPExternal: &http.Request{
|
||||
Header: http.Header{
|
||||
"X-Real-Ip": []string{ipForRealIP},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrExternal,
|
||||
},
|
||||
ipTestReqKeyRealIPInternal: &http.Request{
|
||||
Header: http.Header{
|
||||
"X-Real-Ip": []string{ipForRealIP},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrLoopback,
|
||||
},
|
||||
ipTestReqKeyRealIPAndXFFExternal: &http.Request{
|
||||
Header: http.Header{
|
||||
"X-Real-Ip": []string{ipForRealIP},
|
||||
HeaderXForwardedFor: []string{sampleXFF},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrExternal,
|
||||
},
|
||||
ipTestReqKeyRealIPAndXFFInternal: &http.Request{
|
||||
Header: http.Header{
|
||||
"X-Real-Ip": []string{ipForRealIP},
|
||||
HeaderXForwardedFor: []string{sampleXFF},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrLoopback,
|
||||
},
|
||||
ipTestReqKeyXFFExternal: &http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{sampleXFF},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrExternal,
|
||||
},
|
||||
ipTestReqKeyXFFInternal: &http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{sampleXFF},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrLoopback,
|
||||
},
|
||||
ipTestReqKeyBrokenXFF: &http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal},
|
||||
},
|
||||
RemoteAddr: sampleRemoteAddrLoopback,
|
||||
},
|
||||
func mustParseCIDR(s string) *net.IPNet {
|
||||
_, IPNet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
)
|
||||
return IPNet
|
||||
}
|
||||
|
||||
func TestExtractIP(t *testing.T) {
|
||||
_, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
_, ipv6AllRange, _ := net.ParseCIDR("::/0")
|
||||
_, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48")
|
||||
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24")
|
||||
|
||||
tests := map[string]*struct {
|
||||
extractor IPExtractor
|
||||
expectedIPs map[string]string
|
||||
func TestIPChecker_TrustOption(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenOptions []TrustOption
|
||||
whenIP string
|
||||
expect bool
|
||||
}{
|
||||
"ExtractIPDirect": {
|
||||
ExtractIPDirect(),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
{
|
||||
name: "ip is within trust range, trusts additional private IPV6 network",
|
||||
givenOptions: []TrustOption{
|
||||
TrustLoopback(false),
|
||||
TrustLinkLocal(false),
|
||||
TrustPrivateNet(false),
|
||||
// this is private IPv6 ip
|
||||
// CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
|
||||
// Address: 2001:0db8:0000:0000:0000:0000:0000:0103
|
||||
// Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
|
||||
// Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
|
||||
TrustIPRange(mustParseCIDR("2001:db8::103/48")),
|
||||
},
|
||||
whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
|
||||
expect: true,
|
||||
},
|
||||
"ExtractIPFromRealIPHeader(default)": {
|
||||
ExtractIPFromRealIPHeader(),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRealIP,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForRealIP,
|
||||
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromRealIPHeader(trust only direct-facing proxy)": {
|
||||
ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRealIP,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForRealIP,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromRealIPHeader(trust direct-facing proxy)": {
|
||||
ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRealIP,
|
||||
ipTestReqKeyRealIPInternal: ipForRealIP,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForRealIP,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForRealIP,
|
||||
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromXFFHeader(default)": {
|
||||
ExtractIPFromXFFHeader(),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External,
|
||||
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyXFFInternal: ipForXFF3External,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromXFFHeader(trust only direct-facing proxy)": {
|
||||
ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyXFFExternal: ipForXFF1LinkLocal,
|
||||
ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromXFFHeader(trust direct-facing proxy)": {
|
||||
ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External,
|
||||
ipTestReqKeyXFFExternal: ipForXFF3External,
|
||||
ipTestReqKeyXFFInternal: ipForXFF3External,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromXFFHeader(trust everything)": {
|
||||
// This is similar to legacy behavior, but ignores x-real-ip header.
|
||||
ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External,
|
||||
ipTestReqKeyXFFExternal: ipForXFF6External,
|
||||
ipTestReqKeyXFFInternal: ipForXFF6External,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
},
|
||||
},
|
||||
"ExtractIPFromXFFHeader(trust ipForXFF3External)": {
|
||||
// This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't
|
||||
ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)),
|
||||
map[string]string{
|
||||
ipTestReqKeyNoHeader: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback,
|
||||
ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External,
|
||||
ipTestReqKeyXFFExternal: ipForRemoteAddrExternal,
|
||||
ipTestReqKeyXFFInternal: ipForXFF5External,
|
||||
ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback,
|
||||
{
|
||||
name: "ip is within trust range, trusts additional private IPV6 network",
|
||||
givenOptions: []TrustOption{
|
||||
TrustIPRange(mustParseCIDR("2001:db8::103/48")),
|
||||
},
|
||||
whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := testify.New(t)
|
||||
for key, req := range requests {
|
||||
actual := test.extractor(req)
|
||||
expected := test.expectedIPs[key]
|
||||
assert.Equal(expected, actual, "Request: %s", key)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
checker := newIPChecker(tc.givenOptions)
|
||||
|
||||
result := checker.trust(net.ParseIP(tc.whenIP))
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustIPRange(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRange string
|
||||
whenIP string
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "ip is within trust range, IPV6 network range",
|
||||
// CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
|
||||
// Address: 2001:0db8:0000:0000:0000:0000:0000:0103
|
||||
// Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
|
||||
// Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
|
||||
givenRange: "2001:db8::103/48",
|
||||
whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "ip is outside (upper bounds) of trust range, IPV6 network range",
|
||||
givenRange: "2001:db8::103/48",
|
||||
whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "ip is outside (lower bounds) of trust range, IPV6 network range",
|
||||
givenRange: "2001:db8::103/48",
|
||||
whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "ip is within trust range, IPV4 network range",
|
||||
// CIDR Notation: 8.8.8.8/24
|
||||
// Address: 8.8.8.8
|
||||
// Range start: 8.8.8.0
|
||||
// Range end: 8.8.8.255
|
||||
givenRange: "8.8.8.0/24",
|
||||
whenIP: "8.8.8.8",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "ip is within trust range, IPV4 network range",
|
||||
// CIDR Notation: 8.8.8.8/24
|
||||
// Address: 8.8.8.8
|
||||
// Range start: 8.8.8.0
|
||||
// Range end: 8.8.8.255
|
||||
givenRange: "8.8.8.0/24",
|
||||
whenIP: "8.8.8.8",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "ip is outside (upper bounds) of trust range, IPV4 network range",
|
||||
givenRange: "8.8.8.0/24",
|
||||
whenIP: "8.8.9.0",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "ip is outside (lower bounds) of trust range, IPV4 network range",
|
||||
givenRange: "8.8.8.0/24",
|
||||
whenIP: "8.8.7.255",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "public ip, trust everything in IPV4 network range",
|
||||
givenRange: "0.0.0.0/0",
|
||||
whenIP: "8.8.8.8",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "internal ip, trust everything in IPV4 network range",
|
||||
givenRange: "0.0.0.0/0",
|
||||
whenIP: "127.0.10.1",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "public ip, trust everything in IPV6 network range",
|
||||
givenRange: "::/0",
|
||||
whenIP: "2a00:1450:4026:805::200e",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "internal ip, trust everything in IPV6 network range",
|
||||
givenRange: "::/0",
|
||||
whenIP: "0:0:0:0:0:0:0:1",
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cidr := mustParseCIDR(tc.givenRange)
|
||||
checker := newIPChecker([]TrustOption{
|
||||
TrustLoopback(false), // disable to avoid interference
|
||||
TrustLinkLocal(false), // disable to avoid interference
|
||||
TrustPrivateNet(false), // disable to avoid interference
|
||||
|
||||
TrustIPRange(cidr),
|
||||
})
|
||||
|
||||
result := checker.trust(net.ParseIP(tc.whenIP))
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustPrivateNet(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenIP string
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "do not trust public IPv4 address",
|
||||
whenIP: "8.8.8.8",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "do not trust public IPv6 address",
|
||||
whenIP: "2a00:1450:4026:805::200e",
|
||||
expect: false,
|
||||
},
|
||||
|
||||
{ // Class A: 10.0.0.0 — 10.255.255.255
|
||||
name: "do not trust IPv4 just outside of class A (lower bounds)",
|
||||
whenIP: "9.255.255.255",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "do not trust IPv4 just outside of class A (upper bounds)",
|
||||
whenIP: "11.0.0.0",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "trust IPv4 of class A (lower bounds)",
|
||||
whenIP: "10.0.0.0",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "trust IPv4 of class A (upper bounds)",
|
||||
whenIP: "10.255.255.255",
|
||||
expect: true,
|
||||
},
|
||||
|
||||
{ // Class B: 172.16.0.0 — 172.31.255.255
|
||||
name: "do not trust IPv4 just outside of class B (lower bounds)",
|
||||
whenIP: "172.15.255.255",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "do not trust IPv4 just outside of class B (upper bounds)",
|
||||
whenIP: "172.32.0.0",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "trust IPv4 of class B (lower bounds)",
|
||||
whenIP: "172.16.0.0",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "trust IPv4 of class B (upper bounds)",
|
||||
whenIP: "172.31.255.255",
|
||||
expect: true,
|
||||
},
|
||||
|
||||
{ // Class C: 192.168.0.0 — 192.168.255.255
|
||||
name: "do not trust IPv4 just outside of class C (lower bounds)",
|
||||
whenIP: "192.167.255.255",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "do not trust IPv4 just outside of class C (upper bounds)",
|
||||
whenIP: "192.169.0.0",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "trust IPv4 of class C (lower bounds)",
|
||||
whenIP: "192.168.0.0",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "trust IPv4 of class C (upper bounds)",
|
||||
whenIP: "192.168.255.255",
|
||||
expect: true,
|
||||
},
|
||||
|
||||
{ // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
|
||||
// splits the address block in two equally sized halves, fc00::/8 and fd00::/8.
|
||||
// https://en.wikipedia.org/wiki/Unique_local_address
|
||||
name: "trust IPv6 private address",
|
||||
whenIP: "fdfc:3514:2cb3:4bd5::",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "do not trust IPv6 just out of /fd (upper bounds)",
|
||||
whenIP: "/fe00:0000:0000:0000:0000",
|
||||
expect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
checker := newIPChecker([]TrustOption{
|
||||
TrustLoopback(false), // disable to avoid interference
|
||||
TrustLinkLocal(false), // disable to avoid interference
|
||||
|
||||
TrustPrivateNet(true),
|
||||
})
|
||||
|
||||
result := checker.trust(net.ParseIP(tc.whenIP))
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustLinkLocal(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenIP string
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "trust link local IPv4 address (lower bounds)",
|
||||
whenIP: "169.254.0.0",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "trust link local IPv4 address (upper bounds)",
|
||||
whenIP: "169.254.255.255",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "do not trust link local IPv4 address (outside of lower bounds)",
|
||||
whenIP: "169.253.255.255",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "do not trust link local IPv4 address (outside of upper bounds)",
|
||||
whenIP: "169.255.0.0",
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "trust link local IPv6 address ",
|
||||
whenIP: "fe80::1",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "do not trust link local IPv6 address ",
|
||||
whenIP: "fec0::1",
|
||||
expect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
checker := newIPChecker([]TrustOption{
|
||||
TrustLoopback(false), // disable to avoid interference
|
||||
TrustPrivateNet(false), // disable to avoid interference
|
||||
|
||||
TrustLinkLocal(true),
|
||||
})
|
||||
|
||||
result := checker.trust(net.ParseIP(tc.whenIP))
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrustLoopback(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenIP string
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "trust IPv4 as localhost",
|
||||
whenIP: "127.0.0.1",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "trust IPv6 as localhost",
|
||||
whenIP: "::1",
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "do not trust public ip as localhost",
|
||||
whenIP: "8.8.8.8",
|
||||
expect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
checker := newIPChecker([]TrustOption{
|
||||
TrustLinkLocal(false), // disable to avoid interference
|
||||
TrustPrivateNet(false), // disable to avoid interference
|
||||
|
||||
TrustLoopback(true),
|
||||
})
|
||||
|
||||
result := checker.trust(net.ParseIP(tc.whenIP))
|
||||
assert.Equal(t, tc.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractIPDirect(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenRequest http.Request
|
||||
expectIP string
|
||||
}{
|
||||
{
|
||||
name: "request has no headers, extracts IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"203.0.113.10"},
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"203.0.113.10"},
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080",
|
||||
},
|
||||
expectIP: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"203.0.113.10"},
|
||||
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"127.0.0.1"},
|
||||
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080",
|
||||
},
|
||||
expectIP: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080",
|
||||
},
|
||||
expectIP: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"},
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080",
|
||||
},
|
||||
expectIP: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
extractedIP := ExtractIPDirect()(&tc.whenRequest)
|
||||
assert.Equal(t, tc.expectIP, extractedIP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractIPFromRealIPHeader(t *testing.T) {
|
||||
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenTrustOptions []TrustOption
|
||||
whenRequest http.Request
|
||||
expectIP string
|
||||
}{
|
||||
{
|
||||
name: "request has no headers, extracts IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
|
||||
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
|
||||
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
|
||||
},
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"203.0.113.199"},
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.199",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
|
||||
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
|
||||
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
|
||||
},
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXRealIP: []string{"203.0.113.199"},
|
||||
HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.199",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest)
|
||||
assert.Equal(t, tc.expectIP, extractedIP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractIPFromXFFHeader(t *testing.T) {
|
||||
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenTrustOptions []TrustOption
|
||||
whenRequest http.Request
|
||||
expectIP string
|
||||
}{
|
||||
{
|
||||
name: "request has no headers, extracts IP from request remote addr",
|
||||
whenRequest: http.Request{
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request has INVALID external XFF header, extract IP from remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080",
|
||||
},
|
||||
expectIP: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"},
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080",
|
||||
},
|
||||
expectIP: "127.0.0.3",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted
|
||||
},
|
||||
RemoteAddr: "203.0.113.1:8080",
|
||||
},
|
||||
expectIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
|
||||
givenTrustOptions: []TrustOption{
|
||||
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
|
||||
},
|
||||
// from request its seems that request has been proxied through 6 servers.
|
||||
// 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed)
|
||||
// 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs)
|
||||
// 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office)
|
||||
// 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
|
||||
// 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
|
||||
whenRequest: http.Request{
|
||||
Header: http.Header{
|
||||
HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"},
|
||||
},
|
||||
RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP
|
||||
},
|
||||
expectIP: "203.0.100.100", // this is first trusted IP in XFF chain
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest)
|
||||
assert.Equal(t, tc.expectIP, extractedIP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -72,9 +72,11 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Invalid base64 shouldn't be treated as error
|
||||
// instead should be treated as invalid client input
|
||||
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if errDecode != nil {
|
||||
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
|
||||
lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode)
|
||||
continue
|
||||
}
|
||||
idx := bytes.IndexByte(b, ':')
|
||||
|
@ -56,7 +56,7 @@ func TestBasicAuth(t *testing.T) {
|
||||
name: "nok, not base64 Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
|
||||
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
|
||||
expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3",
|
||||
},
|
||||
{
|
||||
name: "nok, missing Authorization header",
|
||||
|
@ -43,6 +43,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
|
||||
|
||||
// Based on content read (within limit)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
req.ContentLength = -1
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
|
||||
@ -55,6 +56,7 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
|
||||
|
||||
// Based on content read (overlimit)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
|
||||
req.ContentLength = -1
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
|
||||
|
@ -144,7 +144,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
var lastTokenErr error
|
||||
outer:
|
||||
for _, extractor := range extractors {
|
||||
clientTokens, err := extractor(c)
|
||||
clientTokens, _, err := extractor(c)
|
||||
if err != nil {
|
||||
lastExtractorErr = err
|
||||
continue
|
||||
|
@ -79,9 +79,6 @@ func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
|
||||
|
||||
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
h := Decompress()(func(c echo.Context) error {
|
||||
c.Response().Write([]byte("test")) // For Content-Type sniffing
|
||||
@ -91,10 +88,10 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
|
||||
// Decompress
|
||||
body := `{"name": "echo"}`
|
||||
gz, _ := gzipString(body)
|
||||
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
|
||||
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := h(c)
|
||||
assert.NoError(t, err)
|
||||
|
@ -13,6 +13,24 @@ const (
|
||||
extractorLimit = 20
|
||||
)
|
||||
|
||||
// ExtractorSource is type to indicate source for extracted value
|
||||
type ExtractorSource string
|
||||
|
||||
const (
|
||||
// ExtractorSourceHeader means value was extracted from request header
|
||||
ExtractorSourceHeader ExtractorSource = "header"
|
||||
// ExtractorSourceQuery means value was extracted from request query parameters
|
||||
ExtractorSourceQuery ExtractorSource = "query"
|
||||
// ExtractorSourcePathParam means value was extracted from route path parameters
|
||||
ExtractorSourcePathParam ExtractorSource = "param"
|
||||
// ExtractorSourceCookie means value was extracted from request cookies
|
||||
ExtractorSourceCookie ExtractorSource = "cookie"
|
||||
// ExtractorSourceForm means value was extracted from request form values
|
||||
ExtractorSourceForm ExtractorSource = "form"
|
||||
// ExtractorSourceCustom means value was extracted by custom extractor
|
||||
ExtractorSourceCustom ExtractorSource = "custom"
|
||||
)
|
||||
|
||||
// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
|
||||
type ValueExtractorError struct {
|
||||
message string
|
||||
@ -31,7 +49,7 @@ var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing valu
|
||||
var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
|
||||
|
||||
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
|
||||
type ValuesExtractor func(c echo.Context) ([]string, error)
|
||||
type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error)
|
||||
|
||||
func createExtractors(lookups string) ([]ValuesExtractor, error) {
|
||||
if lookups == "" {
|
||||
@ -75,10 +93,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
|
||||
prefixLen := len(valuePrefix)
|
||||
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
|
||||
header = textproto.CanonicalMIMEHeaderKey(header)
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
return func(c echo.Context) ([]string, ExtractorSource, error) {
|
||||
values := c.Request().Header.Values(header)
|
||||
if len(values) == 0 {
|
||||
return nil, errHeaderExtractorValueMissing
|
||||
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
@ -100,30 +118,30 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
|
||||
|
||||
if len(result) == 0 {
|
||||
if prefixLen > 0 {
|
||||
return nil, errHeaderExtractorValueInvalid
|
||||
return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
|
||||
}
|
||||
return nil, errHeaderExtractorValueMissing
|
||||
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
|
||||
}
|
||||
return result, nil
|
||||
return result, ExtractorSourceHeader, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromQuery returns a function that extracts values from the query string.
|
||||
func valuesFromQuery(param string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
return func(c echo.Context) ([]string, ExtractorSource, error) {
|
||||
result := c.QueryParams()[param]
|
||||
if len(result) == 0 {
|
||||
return nil, errQueryExtractorValueMissing
|
||||
return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
|
||||
} else if len(result) > extractorLimit-1 {
|
||||
result = result[:extractorLimit]
|
||||
}
|
||||
return result, nil
|
||||
return result, ExtractorSourceQuery, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromParam returns a function that extracts values from the url param string.
|
||||
func valuesFromParam(param string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
return func(c echo.Context) ([]string, ExtractorSource, error) {
|
||||
result := make([]string, 0)
|
||||
for i, p := range c.PathParams() {
|
||||
if param == p.Name {
|
||||
@ -134,18 +152,18 @@ func valuesFromParam(param string) ValuesExtractor {
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errParamExtractorValueMissing
|
||||
return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
|
||||
}
|
||||
return result, nil
|
||||
return result, ExtractorSourcePathParam, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromCookie returns a function that extracts values from the named cookie.
|
||||
func valuesFromCookie(name string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
return func(c echo.Context) ([]string, ExtractorSource, error) {
|
||||
cookies := c.Cookies()
|
||||
if len(cookies) == 0 {
|
||||
return nil, errCookieExtractorValueMissing
|
||||
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
@ -158,26 +176,26 @@ func valuesFromCookie(name string) ValuesExtractor {
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errCookieExtractorValueMissing
|
||||
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
|
||||
}
|
||||
return result, nil
|
||||
return result, ExtractorSourceCookie, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromForm returns a function that extracts values from the form field.
|
||||
func valuesFromForm(name string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
if parseErr := c.Request().ParseForm(); parseErr != nil {
|
||||
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
|
||||
return func(c echo.Context) ([]string, ExtractorSource, error) {
|
||||
if c.Request().Form == nil {
|
||||
_ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
|
||||
}
|
||||
values := c.Request().Form[name]
|
||||
if len(values) == 0 {
|
||||
return nil, errFormExtractorValueMissing
|
||||
return nil, ExtractorSourceForm, errFormExtractorValueMissing
|
||||
}
|
||||
if len(values) > extractorLimit-1 {
|
||||
values = values[:extractorLimit]
|
||||
}
|
||||
result := append([]string{}, values...)
|
||||
return result, nil
|
||||
return result, ExtractorSourceForm, nil
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -18,6 +20,7 @@ func TestCreateExtractors(t *testing.T) {
|
||||
givenPathParams echo.PathParams
|
||||
whenLoopups string
|
||||
expectValues []string
|
||||
expectSource ExtractorSource
|
||||
expectCreateError string
|
||||
expectError string
|
||||
}{
|
||||
@ -30,6 +33,7 @@ func TestCreateExtractors(t *testing.T) {
|
||||
},
|
||||
whenLoopups: "header:Authorization:Bearer ",
|
||||
expectValues: []string{"token"},
|
||||
expectSource: ExtractorSourceHeader,
|
||||
},
|
||||
{
|
||||
name: "ok, form",
|
||||
@ -43,6 +47,7 @@ func TestCreateExtractors(t *testing.T) {
|
||||
},
|
||||
whenLoopups: "form:name",
|
||||
expectValues: []string{"Jon Snow"},
|
||||
expectSource: ExtractorSourceForm,
|
||||
},
|
||||
{
|
||||
name: "ok, cookie",
|
||||
@ -53,6 +58,7 @@ func TestCreateExtractors(t *testing.T) {
|
||||
},
|
||||
whenLoopups: "cookie:_csrf",
|
||||
expectValues: []string{"token"},
|
||||
expectSource: ExtractorSourceCookie,
|
||||
},
|
||||
{
|
||||
name: "ok, param",
|
||||
@ -61,6 +67,7 @@ func TestCreateExtractors(t *testing.T) {
|
||||
},
|
||||
whenLoopups: "param:id",
|
||||
expectValues: []string{"123"},
|
||||
expectSource: ExtractorSourcePathParam,
|
||||
},
|
||||
{
|
||||
name: "ok, query",
|
||||
@ -70,6 +77,7 @@ func TestCreateExtractors(t *testing.T) {
|
||||
},
|
||||
whenLoopups: "query:id",
|
||||
expectValues: []string{"999"},
|
||||
expectSource: ExtractorSourceQuery,
|
||||
},
|
||||
{
|
||||
name: "nok, invalid lookup",
|
||||
@ -100,8 +108,9 @@ func TestCreateExtractors(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, e := range extractors {
|
||||
values, eErr := e(c)
|
||||
values, source, eErr := e(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, tc.expectSource, source)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, eErr, tc.expectError)
|
||||
return
|
||||
@ -226,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) {
|
||||
|
||||
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
|
||||
|
||||
values, err := extractor(c)
|
||||
values, source, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, ExtractorSourceHeader, source)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
@ -287,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) {
|
||||
|
||||
extractor := valuesFromQuery(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
values, source, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, ExtractorSourceQuery, source)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
@ -366,8 +377,9 @@ func TestValuesFromParam(t *testing.T) {
|
||||
|
||||
extractor := valuesFromParam(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
values, source, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, ExtractorSourcePathParam, source)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
@ -446,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) {
|
||||
|
||||
extractor := valuesFromCookie(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
values, source, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, ExtractorSourceCookie, source)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
@ -483,6 +496,25 @@ func TestValuesFromForm(t *testing.T) {
|
||||
return req
|
||||
}
|
||||
|
||||
exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request {
|
||||
var b bytes.Buffer
|
||||
w := multipart.NewWriter(&b)
|
||||
w.WriteField("name", "Jon Snow")
|
||||
w.WriteField("emails[]", "jon@labstack.com")
|
||||
if mod != nil {
|
||||
mod(w)
|
||||
}
|
||||
|
||||
fw, _ := w.CreateFormFile("upload", "my.file")
|
||||
fw.Write([]byte(`<div>hi</div>`))
|
||||
w.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String()))
|
||||
req.Header.Add(echo.HeaderContentType, w.FormDataContentType())
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest *http.Request
|
||||
@ -504,6 +536,14 @@ func TestValuesFromForm(t *testing.T) {
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, POST multipart/form, multiple value",
|
||||
givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) {
|
||||
w.WriteField("emails[]", "snow@labstack.com")
|
||||
}),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, GET form, single value",
|
||||
givenRequest: exampleGetFormRequest(nil),
|
||||
@ -524,16 +564,6 @@ func TestValuesFromForm(t *testing.T) {
|
||||
whenName: "nope",
|
||||
expectError: errFormExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, POST form, form parsing error",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Body = nil
|
||||
return req
|
||||
}(),
|
||||
whenName: "name",
|
||||
expectError: "valuesFromForm parse form failed: missing form body",
|
||||
},
|
||||
{
|
||||
name: "ok, cut values over extractorLimit",
|
||||
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||
@ -559,8 +589,9 @@ func TestValuesFromForm(t *testing.T) {
|
||||
|
||||
extractor := valuesFromForm(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
values, source, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
assert.Equal(t, ExtractorSourceForm, source)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
|
@ -64,7 +64,7 @@ type JWTConfig struct {
|
||||
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
|
||||
// parsing fails or parsed token is invalid.
|
||||
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
|
||||
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
|
||||
ParseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)
|
||||
}
|
||||
|
||||
// JWTSuccessHandler defines a function which is executed for a valid token.
|
||||
@ -101,7 +101,7 @@ var DefaultJWTConfig = JWTConfig{
|
||||
// For missing token, it returns "400 - Bad Request" error.
|
||||
//
|
||||
// See: https://jwt.io/introduction
|
||||
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
|
||||
func JWT(parseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)) echo.MiddlewareFunc {
|
||||
c := DefaultJWTConfig
|
||||
c.ParseTokenFunc = parseTokenFunc
|
||||
return JWTWithConfig(c)
|
||||
@ -152,13 +152,13 @@ func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
var lastExtractorErr error
|
||||
var lastTokenErr error
|
||||
for _, extractor := range extractors {
|
||||
auths, extrErr := extractor(c)
|
||||
auths, source, extrErr := extractor(c)
|
||||
if extrErr != nil {
|
||||
lastExtractorErr = extrErr
|
||||
continue
|
||||
}
|
||||
for _, auth := range auths {
|
||||
token, err := config.ParseTokenFunc(c, auth)
|
||||
token, err := config.ParseTokenFunc(c, auth, source)
|
||||
if err != nil {
|
||||
lastTokenErr = err
|
||||
continue
|
||||
|
@ -21,7 +21,7 @@ import (
|
||||
// This is one of the options to provide a token validation key.
|
||||
// The order of precedence is a user-defined SigningKeys and SigningKey.
|
||||
// Required if signingKey is not provided
|
||||
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
|
||||
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
|
||||
// keyFunc defines a user-defined function that supplies the public key for a token validation.
|
||||
// The function shall take care of verifying the signing algorithm and selecting the proper key.
|
||||
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
|
||||
@ -41,7 +41,7 @@ func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]in
|
||||
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
|
||||
}
|
||||
|
||||
return func(c echo.Context, auth string) (interface{}, error) {
|
||||
return func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
|
||||
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
|
||||
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if t.Method.Alg() != signingMethod {
|
||||
@ -23,7 +23,7 @@ func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface
|
||||
return signingKey, nil
|
||||
}
|
||||
|
||||
return func(c echo.Context, auth string) (interface{}, error) {
|
||||
return func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -405,7 +405,6 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
|
||||
{
|
||||
name: "ok, ErrorHandler is executed",
|
||||
given: JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
ErrorHandler: func(c echo.Context, err error) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
|
||||
},
|
||||
@ -424,7 +423,7 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
|
||||
|
||||
config := tc.given
|
||||
parseTokenCalled := false
|
||||
config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
|
||||
config.ParseTokenFunc = func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
|
||||
parseTokenCalled = true
|
||||
return nil, errors.New("parsing failed")
|
||||
}
|
||||
@ -453,8 +452,10 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
|
||||
// with current JWT middleware
|
||||
signingKey := []byte("secret")
|
||||
|
||||
var fromSource ExtractorSource
|
||||
config := JWTConfig{
|
||||
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
|
||||
ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
|
||||
fromSource = source
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if t.Method.Alg() != "HS256" {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
@ -481,6 +482,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, fromSource, ExtractorSourceHeader)
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
}
|
||||
|
||||
@ -494,7 +496,7 @@ func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
mw, err := JWTConfig{
|
||||
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
|
||||
ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
|
||||
return auth, nil
|
||||
},
|
||||
SuccessHandler: func(c echo.Context) {
|
||||
@ -616,7 +618,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) {
|
||||
mw, err := JWTConfig{
|
||||
ContinueOnIgnoredError: tc.givenContinueOnIgnoredError,
|
||||
TokenLookup: tc.givenTokenLookup,
|
||||
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
|
||||
ParseTokenFunc: func(c echo.Context, auth string, source ExtractorSource) (interface{}, error) {
|
||||
return tc.whenParseReturn, tc.whenParseError
|
||||
},
|
||||
ErrorHandler: tc.givenErrorHandler,
|
||||
@ -648,8 +650,8 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
|
||||
e.Use(JWTWithConfig(JWTConfig{
|
||||
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
|
||||
TokenLookupFuncs: []ValuesExtractor{
|
||||
func(c echo.Context) ([]string, error) {
|
||||
return []string{c.Request().Header.Get("X-API-Key")}, nil
|
||||
func(c echo.Context) ([]string, ExtractorSource, error) {
|
||||
return []string{c.Request().Header.Get("X-API-Key")}, ExtractorSourceCustom, nil
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
@ -50,7 +50,7 @@ type KeyAuthConfig struct {
|
||||
}
|
||||
|
||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||
type KeyAuthValidator func(c echo.Context, key string) (bool, error)
|
||||
type KeyAuthValidator func(c echo.Context, key string, source ExtractorSource) (bool, error)
|
||||
|
||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||
type KeyAuthErrorHandler func(c echo.Context, err error) error
|
||||
@ -116,13 +116,13 @@ func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
var lastExtractorErr error
|
||||
var lastValidatorErr error
|
||||
for _, extractor := range extractors {
|
||||
keys, extrErr := extractor(c)
|
||||
keys, source, extrErr := extractor(c)
|
||||
if extrErr != nil {
|
||||
lastExtractorErr = extrErr
|
||||
continue
|
||||
}
|
||||
for _, key := range keys {
|
||||
valid, err := config.Validator(c, key)
|
||||
valid, err := config.Validator(c, key, source)
|
||||
if err != nil {
|
||||
lastValidatorErr = err
|
||||
continue
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testKeyValidator(c echo.Context, key string) (bool, error) {
|
||||
func testKeyValidator(c echo.Context, key string, source ExtractorSource) (bool, error) {
|
||||
switch key {
|
||||
case "valid-key":
|
||||
return true, nil
|
||||
@ -218,6 +218,25 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=Unauthorized, internal=some user defined error",
|
||||
},
|
||||
{
|
||||
name: "ok, custom validator checks source",
|
||||
givenRequest: func(req *http.Request) {
|
||||
q := req.URL.Query()
|
||||
q.Add("key", "valid-key")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "query:key"
|
||||
conf.Validator = func(c echo.Context, key string, source ExtractorSource) (bool, error) {
|
||||
if source == ExtractorSourceQuery {
|
||||
return true, nil
|
||||
}
|
||||
return false, errors.New("invalid source")
|
||||
}
|
||||
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -267,7 +286,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) {
|
||||
{
|
||||
name: "ok, no error",
|
||||
whenConfig: KeyAuthConfig{
|
||||
Validator: func(c echo.Context, key string) (bool, error) {
|
||||
Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
@ -283,7 +302,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) {
|
||||
name: "ok, extractor source can not be split",
|
||||
whenConfig: KeyAuthConfig{
|
||||
KeyLookup: "nope",
|
||||
Validator: func(c echo.Context, key string) (bool, error) {
|
||||
Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
@ -293,7 +312,7 @@ func TestKeyAuthWithConfig_errors(t *testing.T) {
|
||||
name: "ok, no extractors",
|
||||
whenConfig: KeyAuthConfig{
|
||||
KeyLookup: "nope:nope",
|
||||
Validator: func(c echo.Context, key string) (bool, error) {
|
||||
Validator: func(c echo.Context, key string, source ExtractorSource) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
|
@ -22,6 +22,8 @@ type LoggerConfig struct {
|
||||
// Tags to construct the logger format.
|
||||
//
|
||||
// - time_unix
|
||||
// - time_unix_milli
|
||||
// - time_unix_micro
|
||||
// - time_unix_nano
|
||||
// - time_rfc3339
|
||||
// - time_rfc3339_nano
|
||||
@ -119,6 +121,10 @@ func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
switch tag {
|
||||
case "time_unix":
|
||||
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
|
||||
case "time_unix_milli":
|
||||
return buf.WriteString(strconv.FormatInt(time.Now().UnixMilli(), 10))
|
||||
case "time_unix_micro":
|
||||
return buf.WriteString(strconv.FormatInt(time.Now().UnixMicro(), 10))
|
||||
case "time_unix_nano":
|
||||
return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10))
|
||||
case "time_rfc3339":
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -172,6 +173,52 @@ func TestLoggerCustomTimestamp(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestLoggerTemplateWithTimeUnixMilli(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
e := echo.New()
|
||||
e.Use(LoggerWithConfig(LoggerConfig{
|
||||
Format: `${time_unix_milli}`,
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
unixMillis, err := strconv.ParseInt(buf.String(), 10, 64)
|
||||
assert.NoError(t, err)
|
||||
assert.WithinDuration(t, time.Unix(unixMillis/1000, 0), time.Now(), 3*time.Second)
|
||||
}
|
||||
|
||||
func TestLoggerTemplateWithTimeUnixMicro(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
e := echo.New()
|
||||
e.Use(LoggerWithConfig(LoggerConfig{
|
||||
Format: `${time_unix_micro}`,
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
unixMicros, err := strconv.ParseInt(buf.String(), 10, 64)
|
||||
assert.NoError(t, err)
|
||||
assert.WithinDuration(t, time.Unix(unixMicros/1000000, 0), time.Now(), 3*time.Second)
|
||||
}
|
||||
|
||||
func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) {
|
||||
e := echo.New()
|
||||
|
||||
|
@ -32,7 +32,6 @@ func TestMethodOverride(t *testing.T) {
|
||||
|
||||
func TestMethodOverride_formParam(t *testing.T) {
|
||||
e := echo.New()
|
||||
m := MethodOverride()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
@ -53,7 +52,6 @@ func TestMethodOverride_formParam(t *testing.T) {
|
||||
|
||||
func TestMethodOverride_queryParam(t *testing.T) {
|
||||
e := echo.New()
|
||||
m := MethodOverride()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
@ -341,10 +341,9 @@ func TestProxyError(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Remote unreachable
|
||||
rec = httptest.NewRecorder()
|
||||
rec := httptest.NewRecorder()
|
||||
req.URL.Path = "/api/users"
|
||||
e.ServeHTTP(rec, req)
|
||||
assert.Equal(t, "/api/users", req.URL.Path)
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
@ -63,6 +64,9 @@ func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if r == http.ErrAbortHandler {
|
||||
panic(r)
|
||||
}
|
||||
tmpErr, ok := r.(error)
|
||||
if !ok {
|
||||
tmpErr = fmt.Errorf("%v", r)
|
||||
|
@ -51,6 +51,33 @@ func TestRecover_skipper(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
|
||||
}
|
||||
|
||||
func TestRecoverErrAbortHandler(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := Recover()(func(c echo.Context) error {
|
||||
panic(http.ErrAbortHandler)
|
||||
})
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`")
|
||||
} else {
|
||||
if err, ok := r.(error); ok {
|
||||
assert.ErrorIs(t, err, http.ErrAbortHandler)
|
||||
} else {
|
||||
assert.Fail(t, "not of error type")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
hErr := h(c)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.NotContains(t, hErr.Error(), "PANIC RECOVER")
|
||||
}
|
||||
|
||||
func TestRecoverWithConfig(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
|
113
router.go
113
router.go
@ -8,6 +8,15 @@ import (
|
||||
)
|
||||
|
||||
// Router is interface for routing request contexts to registered routes.
|
||||
//
|
||||
// Contract between Echo/Context instance and the router:
|
||||
// * all routes must be added through methods on echo.Echo instance.
|
||||
// Reason: Echo instance uses RouteInfo.Params() length to allocate slice for paths parameters (see `Echo.contextPathParamAllocSize`).
|
||||
// * Router must populate Context during Router.Route call with:
|
||||
// * RoutableContext.SetPath
|
||||
// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns)
|
||||
// * RoutableContext.SetRouteInfo
|
||||
// And optionally can set additional information to Context with RoutableContext.Set
|
||||
type Router interface {
|
||||
// Add registers Routable with the Router and returns registered RouteInfo
|
||||
Add(routable Routable) (RouteInfo, error)
|
||||
@ -19,11 +28,6 @@ type Router interface {
|
||||
// Route searches Router for matching route and applies it to the given context. In case when no matching method
|
||||
// was not found (405) or no matching route exists for path (404), router will return its implementation of 405/404
|
||||
// handler function.
|
||||
// When implementing custom Router make sure to populate Context during Router.Route call with:
|
||||
// * RoutableContext.SetPath
|
||||
// * RoutableContext.SetRawPathParams (IMPORTANT! with same slice pointer that c.RawPathParams() returns)
|
||||
// * RoutableContext.SetRouteInfo
|
||||
// And optionally can set additional information to Context with RoutableContext.Set
|
||||
Route(c RoutableContext) HandlerFunc
|
||||
}
|
||||
|
||||
@ -210,18 +214,22 @@ type routeMethod struct {
|
||||
}
|
||||
|
||||
type routeMethods struct {
|
||||
connect *routeMethod
|
||||
delete *routeMethod
|
||||
get *routeMethod
|
||||
head *routeMethod
|
||||
options *routeMethod
|
||||
patch *routeMethod
|
||||
post *routeMethod
|
||||
propfind *routeMethod
|
||||
put *routeMethod
|
||||
trace *routeMethod
|
||||
report *routeMethod
|
||||
anyOther map[string]*routeMethod
|
||||
connect *routeMethod
|
||||
delete *routeMethod
|
||||
get *routeMethod
|
||||
head *routeMethod
|
||||
options *routeMethod
|
||||
patch *routeMethod
|
||||
post *routeMethod
|
||||
propfind *routeMethod
|
||||
put *routeMethod
|
||||
trace *routeMethod
|
||||
report *routeMethod
|
||||
anyOther map[string]*routeMethod
|
||||
|
||||
// notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases
|
||||
notFoundHandler *routeMethod
|
||||
|
||||
allowHeader string
|
||||
}
|
||||
|
||||
@ -249,6 +257,9 @@ func (m *routeMethods) set(method string, r *routeMethod) {
|
||||
m.trace = r
|
||||
case REPORT:
|
||||
m.report = r
|
||||
case RouteNotFound:
|
||||
m.notFoundHandler = r
|
||||
return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed
|
||||
default:
|
||||
if m.anyOther == nil {
|
||||
m.anyOther = make(map[string]*routeMethod)
|
||||
@ -353,6 +364,7 @@ func (m *routeMethods) isHandler() bool {
|
||||
m.trace != nil ||
|
||||
m.report != nil ||
|
||||
len(m.anyOther) != 0
|
||||
// RouteNotFound/404 is not considered as a handler
|
||||
}
|
||||
|
||||
// Routes returns all registered routes
|
||||
@ -611,7 +623,12 @@ func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMetho
|
||||
}
|
||||
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
|
||||
} else if lcpLen < prefixLen {
|
||||
// Split node
|
||||
// Split node into two before we insert new node.
|
||||
// This happens when we are inserting path that is submatch of any existing inserted paths.
|
||||
// For example, we have node `/test` and now are about to insert `/te/*`. In that case
|
||||
// 1. overlapping part is `/te` that is used as parent node
|
||||
// 2. `st` is part from existing node that is not matching - it gets its own node (child to `/te`)
|
||||
// 3. `/*` is the new part we are about to insert (child to `/te`)
|
||||
n := newNode(
|
||||
currentNode.kind,
|
||||
currentNode.prefix[lcpLen:],
|
||||
@ -735,10 +752,8 @@ func (n *node) findStaticChild(l byte) *node {
|
||||
}
|
||||
|
||||
func (n *node) findChildWithLabel(l byte) *node {
|
||||
for _, c := range n.staticChildren {
|
||||
if c.label == l {
|
||||
return c
|
||||
}
|
||||
if c := n.findStaticChild(l); c != nil {
|
||||
return c
|
||||
}
|
||||
if l == paramLabel {
|
||||
return n.paramChild
|
||||
@ -751,11 +766,7 @@ func (n *node) findChildWithLabel(l byte) *node {
|
||||
|
||||
func (n *node) setHandler(method string, r *routeMethod) {
|
||||
n.methods.set(method, r)
|
||||
if r != nil && r.handler != nil {
|
||||
n.isHandler = true
|
||||
} else {
|
||||
n.isHandler = n.methods.isHandler()
|
||||
}
|
||||
n.isHandler = n.methods.isHandler()
|
||||
}
|
||||
|
||||
// Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to Context
|
||||
@ -900,7 +911,7 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
||||
// No matching prefix, let's backtrack to the first possible alternative node of the decision path
|
||||
nk, ok := backtrackToNextNodeKind(staticKind)
|
||||
if !ok {
|
||||
break // No other possibilities on the decision path
|
||||
break // No other possibilities on the decision path, handler will be whatever context is reset to.
|
||||
} else if nk == paramKind {
|
||||
goto Param
|
||||
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
|
||||
@ -916,15 +927,21 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
||||
search = search[lcpLen:]
|
||||
searchIndex = searchIndex + lcpLen
|
||||
|
||||
// Finish routing if no remaining search and we are on a node with handler and matching method type
|
||||
if search == "" && currentNode.isHandler {
|
||||
// check if current node has handler registered for http method we are looking for. we store currentNode as
|
||||
// best matching in case we do no find no more routes matching this path+method
|
||||
if previousBestMatchNode == nil {
|
||||
previousBestMatchNode = currentNode
|
||||
}
|
||||
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
|
||||
matchedRouteMethod = rMethod
|
||||
// Finish routing if is no request path remaining to search
|
||||
if search == "" {
|
||||
// in case of node that is handler we have exact method type match or something for 405 to use
|
||||
if currentNode.isHandler {
|
||||
// check if current node has handler registered for http method we are looking for. we store currentNode as
|
||||
// best matching in case we do no find no more routes matching this path+method
|
||||
if previousBestMatchNode == nil {
|
||||
previousBestMatchNode = currentNode
|
||||
}
|
||||
if h := currentNode.methods.find(req.Method); h != nil {
|
||||
matchedRouteMethod = h
|
||||
break
|
||||
}
|
||||
} else if currentNode.methods.notFoundHandler != nil {
|
||||
matchedRouteMethod = currentNode.methods.notFoundHandler
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -944,7 +961,8 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
||||
i := 0
|
||||
l := len(search)
|
||||
if currentNode.isLeaf {
|
||||
// when param node does not have any children then param node should act similarly to any node - consider all remaining search as match
|
||||
// when param node does not have any children (path param is last piece of route path) then param node should
|
||||
// act similarly to any node - consider all remaining search as match
|
||||
i = l
|
||||
} else {
|
||||
for ; i < l && search[i] != '/'; i++ {
|
||||
@ -969,13 +987,16 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
||||
searchIndex += +len(search)
|
||||
search = ""
|
||||
|
||||
// check if current node has handler registered for http method we are looking for. we store currentNode as
|
||||
// best matching in case we do no find no more routes matching this path+method
|
||||
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
|
||||
matchedRouteMethod = rMethod
|
||||
break
|
||||
}
|
||||
// we store currentNode as best matching in case we do not find more routes matching this path+method. Needed for 405
|
||||
if previousBestMatchNode == nil {
|
||||
previousBestMatchNode = currentNode
|
||||
}
|
||||
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
|
||||
matchedRouteMethod = rMethod
|
||||
if currentNode.methods.notFoundHandler != nil {
|
||||
matchedRouteMethod = currentNode.methods.notFoundHandler
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -1017,7 +1038,13 @@ func (r *DefaultRouter) Route(c RoutableContext) HandlerFunc {
|
||||
|
||||
rPath = currentNode.originalPath
|
||||
rInfo = notFoundRouteInfo
|
||||
if currentNode.isHandler {
|
||||
if currentNode.methods.notFoundHandler != nil {
|
||||
matchedRouteMethod = currentNode.methods.notFoundHandler
|
||||
|
||||
rInfo = matchedRouteMethod.routeInfo
|
||||
rPath = matchedRouteMethod.path
|
||||
rHandler = matchedRouteMethod.handler
|
||||
} else if currentNode.isHandler {
|
||||
rInfo = methodNotAllowedRouteInfo
|
||||
|
||||
c.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader)
|
||||
|
315
router_test.go
315
router_test.go
@ -1,7 +1,6 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -2243,6 +2242,64 @@ func TestRouter_Match_DifferentParamNamesForSamePlace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Issue #2164 - this test is meant to document path parameter behaviour when request url has empty value in place
|
||||
// of the path parameter. As tests show the result is different depending on where parameter exists in the route path.
|
||||
func TestDefaultRouter_PathParamsCanMatchEmptyValues(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenURL string
|
||||
expectRoute string
|
||||
expectParam map[string]string
|
||||
expectError error
|
||||
}{
|
||||
{
|
||||
name: "ok, route is matched with even empty param is in the middle and between slashes",
|
||||
whenURL: "/a//b",
|
||||
expectRoute: "/a/:id/b",
|
||||
expectParam: map[string]string{"id": ""},
|
||||
},
|
||||
{
|
||||
name: "ok, route is matched with even empty param is in the middle",
|
||||
whenURL: "/a2/b",
|
||||
expectRoute: "/a2:id/b",
|
||||
expectParam: map[string]string{"id": ""},
|
||||
},
|
||||
{
|
||||
name: "ok, route is NOT matched with even empty param is at the end",
|
||||
whenURL: "/a3/",
|
||||
expectRoute: "",
|
||||
expectParam: map[string]string{},
|
||||
expectError: ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
e.GET("/a/:id/b", handlerFunc)
|
||||
e.GET("/a2:id/b", handlerFunc)
|
||||
e.GET("/a3/:id", handlerFunc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
c := e.NewContext(req, nil).(*DefaultContext)
|
||||
handler := e.router.Route(c)
|
||||
|
||||
err := handler(c)
|
||||
if tc.expectError != nil {
|
||||
assert.Equal(t, tc.expectError, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.expectRoute, c.Path())
|
||||
for param, expectedValue := range tc.expectParam {
|
||||
assert.Equal(t, expectedValue, c.pathParams.Get(param, ""))
|
||||
}
|
||||
checkUnusedParamValues(t, c, tc.expectParam)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Issue #729
|
||||
func TestRouterParamAlias(t *testing.T) {
|
||||
api := []testRoute{
|
||||
@ -3133,6 +3190,232 @@ func TestDefaultRouter_OptionsMethodHandler(t *testing.T) {
|
||||
assert.Equal(t, "not empty", body)
|
||||
}
|
||||
|
||||
func TestRouter_RouteWhenNotFoundRouteWithNodeSplitting(t *testing.T) {
|
||||
e := New()
|
||||
r := e.router
|
||||
|
||||
hf := func(c Context) error {
|
||||
return c.String(http.StatusOK, c.RouteInfo().Name())
|
||||
}
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/test*", Handler: hf, Name: "0"})
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/test*", Handler: hf, Name: "1"})
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/test", Handler: hf, Name: "2"})
|
||||
|
||||
// Tree before:
|
||||
// 1 `/`
|
||||
// 1.1 `*` (any) ID=1
|
||||
// 1.2 `test` (static) ID=2
|
||||
// 1.2.1 `*` (any) ID=0
|
||||
|
||||
// node with path `test` has routeNotFound handler from previous Add call. Now when we insert `/te/st*` into router tree
|
||||
// This means that node `test` is split into `te` and `st` nodes and new node `/st*` is inserted.
|
||||
// On that split `/test` routeNotFound handler must not be lost.
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/te/st*", Handler: hf, Name: "3"})
|
||||
// Tree after:
|
||||
// 1 `/`
|
||||
// 1.1 `*` (any) ID=1
|
||||
// 1.2 `te` (static)
|
||||
// 1.2.1 `st` (static) ID=2
|
||||
// 1.2.1.1 `*` (any) ID=0
|
||||
// 1.2.2 `/st` (static)
|
||||
// 1.2.2.1 `*` (any) ID=3
|
||||
|
||||
_, body := request(http.MethodPut, "/test", e)
|
||||
|
||||
assert.Equal(t, "2", body)
|
||||
}
|
||||
|
||||
func TestRouter_RouteWhenNotFoundRouteAnyKind(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectID int
|
||||
expectParam map[string]string
|
||||
}{
|
||||
{
|
||||
name: "route not existent /xx to not found handler /*",
|
||||
whenURL: "/xx",
|
||||
expectRoute: "/*",
|
||||
expectID: 4,
|
||||
expectParam: map[string]string{"*": "xx"},
|
||||
},
|
||||
{
|
||||
name: "route not existent /a/xx to not found handler /a/*",
|
||||
whenURL: "/a/xx",
|
||||
expectRoute: "/a/*",
|
||||
expectID: 5,
|
||||
expectParam: map[string]string{"*": "xx"},
|
||||
},
|
||||
{
|
||||
name: "route not existent /a/c/dxxx to not found handler /a/c/d*",
|
||||
whenURL: "/a/c/dxxx",
|
||||
expectRoute: "/a/c/d*",
|
||||
expectID: 6,
|
||||
expectParam: map[string]string{"*": "xxx"},
|
||||
},
|
||||
{
|
||||
name: "route /a/c/df to /a/c/df",
|
||||
whenURL: "/a/c/df",
|
||||
expectRoute: "/a/c/df",
|
||||
expectID: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
e.contextPathParamAllocSize = 1
|
||||
r := e.router
|
||||
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"})
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/a/c/df", Handler: handlerHelper("ID", 1), Name: "1"})
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/a/b*", Handler: handlerHelper("ID", 2), Name: "2"})
|
||||
r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 3), Name: "3"})
|
||||
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/a/c/d*", Handler: handlerHelper("ID", 6), Name: "6"})
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/a/*", Handler: handlerHelper("ID", 5), Name: "5"})
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/*", Handler: handlerHelper("ID", 4), Name: "4"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
c := e.NewContext(req, nil).(*DefaultContext)
|
||||
|
||||
handler := r.Route(c)
|
||||
handler(c)
|
||||
|
||||
testValue, _ := c.Get("ID").(int)
|
||||
assert.Equal(t, tc.expectID, testValue)
|
||||
assert.Equal(t, tc.expectRoute, c.Path())
|
||||
for param, expectedValue := range tc.expectParam {
|
||||
assert.Equal(t, expectedValue, c.PathParam(param))
|
||||
}
|
||||
checkUnusedParamValues(t, c, tc.expectParam)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_RouteWhenNotFoundRouteParamKind(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectID int
|
||||
expectParam map[string]string
|
||||
}{
|
||||
{
|
||||
name: "route not existent /xx to not found handler /:file",
|
||||
whenURL: "/xx",
|
||||
expectRoute: "/:file",
|
||||
expectID: 4,
|
||||
expectParam: map[string]string{"file": "xx"},
|
||||
},
|
||||
{
|
||||
name: "route not existent /a/xx to not found handler /a/:file",
|
||||
whenURL: "/a/xx",
|
||||
expectRoute: "/a/:file",
|
||||
expectID: 5,
|
||||
expectParam: map[string]string{"file": "xx"},
|
||||
},
|
||||
{
|
||||
name: "route not existent /a/c/dxxx to not found handler /a/c/d:file",
|
||||
whenURL: "/a/c/dxxx",
|
||||
expectRoute: "/a/c/d:file",
|
||||
expectID: 6,
|
||||
expectParam: map[string]string{"file": "xxx"},
|
||||
},
|
||||
{
|
||||
name: "route /a/c/df to /a/c/df",
|
||||
whenURL: "/a/c/df",
|
||||
expectRoute: "/a/c/df",
|
||||
expectID: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
e.contextPathParamAllocSize = 1
|
||||
r := e.router
|
||||
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"})
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/a/c/df", Handler: handlerHelper("ID", 1), Name: "1"})
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/a/b*", Handler: handlerHelper("ID", 2), Name: "2"})
|
||||
r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 3), Name: "3"})
|
||||
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/a/c/d:file", Handler: handlerHelper("ID", 6), Name: "6"})
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/a/:file", Handler: handlerHelper("ID", 5), Name: "5"})
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/:file", Handler: handlerHelper("ID", 4), Name: "4"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
c := e.NewContext(req, nil).(*DefaultContext)
|
||||
|
||||
handler := r.Route(c)
|
||||
handler(c)
|
||||
|
||||
testValue, _ := c.Get("ID").(int)
|
||||
assert.Equal(t, tc.expectID, testValue)
|
||||
assert.Equal(t, tc.expectRoute, c.Path())
|
||||
for param, expectedValue := range tc.expectParam {
|
||||
assert.Equal(t, expectedValue, c.PathParam(param))
|
||||
}
|
||||
checkUnusedParamValues(t, c, tc.expectParam)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_RouteWhenNotFoundRouteStaticKind(t *testing.T) {
|
||||
// note: static not found handler is quite silly thing to have but we still support it
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectID int
|
||||
expectParam map[string]string
|
||||
}{
|
||||
{
|
||||
name: "route not existent / to not found handler /",
|
||||
whenURL: "/",
|
||||
expectRoute: "/",
|
||||
expectID: 3,
|
||||
expectParam: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "route /a to /a",
|
||||
whenURL: "/a",
|
||||
expectRoute: "/a",
|
||||
expectID: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := New()
|
||||
e.contextPathParamAllocSize = 1
|
||||
r := e.router
|
||||
|
||||
r.Add(Route{Method: http.MethodPut, Path: "/", Handler: handlerHelper("ID", 0), Name: "0"})
|
||||
r.Add(Route{Method: http.MethodGet, Path: "/a", Handler: handlerHelper("ID", 1), Name: "1"})
|
||||
r.Add(Route{Method: http.MethodPut, Path: "/*", Handler: handlerHelper("ID", 2), Name: "2"})
|
||||
|
||||
r.Add(Route{Method: RouteNotFound, Path: "/", Handler: handlerHelper("ID", 3), Name: "3"})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
|
||||
c := e.NewContext(req, nil).(*DefaultContext)
|
||||
|
||||
handler := r.Route(c)
|
||||
handler(c)
|
||||
|
||||
testValue, _ := c.Get("ID").(int)
|
||||
assert.Equal(t, tc.expectID, testValue)
|
||||
assert.Equal(t, tc.expectRoute, c.Path())
|
||||
for param, expectedValue := range tc.expectParam {
|
||||
assert.Equal(t, expectedValue, c.PathParam(param))
|
||||
}
|
||||
checkUnusedParamValues(t, c, tc.expectParam)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mySimpleRouter struct {
|
||||
route Route
|
||||
}
|
||||
@ -3227,33 +3510,3 @@ func BenchmarkRouterGooglePlusAPIMisses(b *testing.B) {
|
||||
func BenchmarkRouterParamsAndAnyAPI(b *testing.B) {
|
||||
benchmarkRouterRoutes(b, paramAndAnyAPI, paramAndAnyAPIToFind)
|
||||
}
|
||||
|
||||
func (n *node) printTree(pfx string, tail bool) {
|
||||
p := prefix(tail, pfx, "└── ", "├── ")
|
||||
fmt.Printf("%s%s, %p: type=%d, parent=%p, handler=%v, paramNames=%v\n", p, n.prefix, n, n.kind, n.parent, n.methods, n.paramsCount)
|
||||
|
||||
p = prefix(tail, pfx, " ", "│ ")
|
||||
|
||||
children := n.staticChildren
|
||||
l := len(children)
|
||||
|
||||
if n.paramChild != nil {
|
||||
n.paramChild.printTree(p, n.anyChild == nil && l == 0)
|
||||
}
|
||||
if n.anyChild != nil {
|
||||
n.anyChild.printTree(p, l == 0)
|
||||
}
|
||||
for i := 0; i < l-1; i++ {
|
||||
children[i].printTree(p, false)
|
||||
}
|
||||
if l > 0 {
|
||||
children[l-1].printTree(p, true)
|
||||
}
|
||||
}
|
||||
|
||||
func prefix(tail bool, p, on, off string) string {
|
||||
if tail {
|
||||
return fmt.Sprintf("%s%s", p, on)
|
||||
}
|
||||
return fmt.Sprintf("%s%s", p, off)
|
||||
}
|
||||
|
@ -696,8 +696,7 @@ func TestStartConfig_WithHidePort(t *testing.T) {
|
||||
}
|
||||
assert.NoError(t, <-errCh)
|
||||
|
||||
portMsg := fmt.Sprintf("http(s) server started on")
|
||||
contains := strings.Contains(buf.String(), portMsg)
|
||||
contains := strings.Contains(buf.String(), "http(s) server started on")
|
||||
if tc.hidePort {
|
||||
assert.False(t, contains)
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user