mirror of
https://github.com/labstack/echo.git
synced 2024-11-24 08:22:21 +02:00
parent
90d675fa2a
commit
f49d166e6f
2
bind.go
2
bind.go
@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
|
||||
val := reflect.ValueOf(ptr).Elem()
|
||||
|
||||
if typ.Kind() != reflect.Struct {
|
||||
return errors.New("Binding element must be a struct")
|
||||
return errors.New("binding element must be a struct")
|
||||
}
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
|
@ -135,8 +135,7 @@ func TestBindForm(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req.Header.Set(HeaderContentType, MIMEApplicationForm)
|
||||
obj := []struct{ Field string }{}
|
||||
err := c.Bind(&obj)
|
||||
err := c.Bind(&[]struct{ Field string }{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
@ -2,21 +2,18 @@ package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"strings"
|
||||
|
||||
"net/url"
|
||||
|
||||
"encoding/xml"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -217,7 +214,7 @@ func TestContext(t *testing.T) {
|
||||
c.SetParamNames("foo")
|
||||
c.SetParamValues("bar")
|
||||
c.Set("foe", "ban")
|
||||
c.query = url.Values(map[string][]string{"fon": []string{"baz"}})
|
||||
c.query = url.Values(map[string][]string{"fon": {"baz"}})
|
||||
c.Reset(req, httptest.NewRecorder())
|
||||
assert.Equal(t, 0, len(c.ParamValues()))
|
||||
assert.Equal(t, 0, len(c.ParamNames()))
|
||||
|
18
echo.go
18
echo.go
@ -251,10 +251,10 @@ var (
|
||||
ErrForbidden = NewHTTPError(http.StatusForbidden)
|
||||
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
|
||||
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
|
||||
ErrValidatorNotRegistered = errors.New("Validator not registered")
|
||||
ErrRendererNotRegistered = errors.New("Renderer not registered")
|
||||
ErrInvalidRedirectCode = errors.New("Invalid redirect status code")
|
||||
ErrCookieNotFound = errors.New("Cookie not found")
|
||||
ErrValidatorNotRegistered = errors.New("validator not registered")
|
||||
ErrRendererNotRegistered = errors.New("renderer not registered")
|
||||
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
|
||||
ErrCookieNotFound = errors.New("cookie not found")
|
||||
)
|
||||
|
||||
// Error handlers
|
||||
@ -530,7 +530,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
|
||||
|
||||
// Routes returns the registered routes.
|
||||
func (e *Echo) Routes() []*Route {
|
||||
routes := []*Route{}
|
||||
routes := make([]*Route, 0, len(e.router.routes))
|
||||
for _, v := range e.router.routes {
|
||||
routes = append(routes, v)
|
||||
}
|
||||
@ -563,11 +563,11 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Middleware
|
||||
h := func(c Context) error {
|
||||
method := r.Method
|
||||
path := r.URL.RawPath
|
||||
if path == "" {
|
||||
path = r.URL.Path
|
||||
rpath := r.URL.RawPath // raw path
|
||||
if rpath == "" {
|
||||
rpath = r.URL.Path
|
||||
}
|
||||
e.router.Find(method, path, c)
|
||||
e.router.Find(method, rpath, c)
|
||||
h := c.Handler()
|
||||
for i := len(e.middleware) - 1; i >= 0; i-- {
|
||||
h = e.middleware[i](h)
|
||||
|
@ -2,15 +2,12 @@ package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"errors"
|
||||
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
4
group.go
4
group.go
@ -92,7 +92,7 @@ func (g *Group) Match(methods []string, path string, handler HandlerFunc, middle
|
||||
|
||||
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
|
||||
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
|
||||
m := []MiddlewareFunc{}
|
||||
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
|
||||
m = append(m, g.middleware...)
|
||||
m = append(m, middleware...)
|
||||
return g.echo.Group(g.prefix+prefix, m...)
|
||||
@ -113,7 +113,7 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd
|
||||
// Combine into a new slice to avoid accidentally passing the same slice for
|
||||
// multiple routes, which would lead to later add() calls overwriting the
|
||||
// middleware from earlier calls.
|
||||
m := []MiddlewareFunc{}
|
||||
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
|
||||
m = append(m, g.middleware...)
|
||||
m = append(m, middleware...)
|
||||
return g.echo.Add(method, g.prefix+path, handler, m...)
|
||||
|
@ -93,10 +93,8 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
realm := ""
|
||||
if config.Realm == defaultRealm {
|
||||
realm = defaultRealm
|
||||
} else {
|
||||
realm := defaultRealm
|
||||
if config.Realm != defaultRealm {
|
||||
realm = strconv.Quote(config.Realm)
|
||||
}
|
||||
|
||||
|
@ -31,6 +31,19 @@ func TestBasicAuth(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
||||
Skipper: nil,
|
||||
Validator: f,
|
||||
Realm: "someRealm",
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Valid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
// Case-insensitive header scheme
|
||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
@ -41,7 +54,7 @@ func TestBasicAuth(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Missing Authorization header
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
|
@ -3,12 +3,11 @@ package middleware
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"io"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -31,10 +32,65 @@ func TestBodyDump(t *testing.T) {
|
||||
requestBody = string(reqBody)
|
||||
responseBody = string(resBody)
|
||||
})
|
||||
|
||||
if assert.NoError(t, mw(h)(c)) {
|
||||
assert.Equal(t, requestBody, hw)
|
||||
assert.Equal(t, responseBody, hw)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, hw, rec.Body.String())
|
||||
}
|
||||
|
||||
// Must set default skipper
|
||||
BodyDumpWithConfig(BodyDumpConfig{
|
||||
Skipper: nil,
|
||||
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||
requestBody = string(reqBody)
|
||||
responseBody = string(resBody)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBodyDumpFails(t *testing.T) {
|
||||
e := echo.New()
|
||||
hw := "Hello, World!"
|
||||
req := httptest.NewRequest(echo.POST, "/", strings.NewReader(hw))
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
h := func(c echo.Context) error {
|
||||
return errors.New("some error")
|
||||
}
|
||||
|
||||
requestBody := ""
|
||||
responseBody := ""
|
||||
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
|
||||
requestBody = string(reqBody)
|
||||
responseBody = string(resBody)
|
||||
})
|
||||
|
||||
if !assert.Error(t, mw(h)(c)) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
assert.Panics(t, func() {
|
||||
mw = BodyDumpWithConfig(BodyDumpConfig{
|
||||
Skipper: nil,
|
||||
Handler: nil,
|
||||
})
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
mw = BodyDumpWithConfig(BodyDumpConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
},
|
||||
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||
requestBody = string(reqBody)
|
||||
responseBody = string(resBody)
|
||||
},
|
||||
})
|
||||
|
||||
if !assert.Error(t, mw(h)(c)) {
|
||||
t.FailNow()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -36,8 +36,8 @@ func TestGzip(t *testing.T) {
|
||||
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
|
||||
r, err := gzip.NewReader(rec.Body)
|
||||
if assert.NoError(t, err) {
|
||||
defer r.Close()
|
||||
buf := new(bytes.Buffer)
|
||||
defer r.Close()
|
||||
buf.ReadFrom(r)
|
||||
assert.Equal(t, "test", buf.String())
|
||||
}
|
||||
|
@ -124,10 +124,11 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
|
||||
req := c.Request()
|
||||
k, err := c.Cookie(config.CookieName)
|
||||
token := ""
|
||||
|
||||
var token string
|
||||
|
||||
// Generate token
|
||||
if err != nil {
|
||||
// Generate token
|
||||
token = random.String(config.TokenLength)
|
||||
} else {
|
||||
// Reuse token
|
||||
@ -187,7 +188,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.FormValue(param)
|
||||
if token == "" {
|
||||
return "", errors.New("Missing csrf token in the form parameter")
|
||||
return "", errors.New("missing csrf token in the form parameter")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
@ -199,7 +200,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.QueryParam(param)
|
||||
if token == "" {
|
||||
return "", errors.New("Missing csrf token in the query string")
|
||||
return "", errors.New("missing csrf token in the query string")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
|
||||
// Check the signing method
|
||||
if t.Method.Alg() != config.SigningMethod {
|
||||
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
return config.SigningKey, nil
|
||||
}
|
||||
|
@ -114,14 +114,14 @@ func keyFromHeader(header string, authScheme string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
auth := c.Request().Header.Get(header)
|
||||
if auth == "" {
|
||||
return "", errors.New("Missing key in request header")
|
||||
return "", errors.New("missing key in request header")
|
||||
}
|
||||
if header == echo.HeaderAuthorization {
|
||||
l := len(authScheme)
|
||||
if len(auth) > l+1 && auth[:l] == authScheme {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", errors.New("Invalid key in the request header")
|
||||
return "", errors.New("invalid key in the request header")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
@ -132,7 +132,7 @@ func keyFromQuery(param string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key := c.QueryParam(param)
|
||||
if key == "" {
|
||||
return "", errors.New("Missing key in the query string")
|
||||
return "", errors.New("missing key in the query string")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
@ -143,7 +143,7 @@ func keyFromForm(param string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key := c.FormValue(param)
|
||||
if key == "" {
|
||||
return "", errors.New("Missing key in the form")
|
||||
return "", errors.New("missing key in the form")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ type (
|
||||
// Example "${remote_ip} ${status}"
|
||||
//
|
||||
// Optional. Default value DefaultLoggerConfig.Format.
|
||||
Format string `yaml:"format"`
|
||||
Format string `yaml:"format"`
|
||||
|
||||
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
|
||||
CustomTimeFormat string `yaml:"custom_time_format"`
|
||||
@ -70,9 +70,9 @@ var (
|
||||
`"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
|
||||
`"latency_human":"${latency_human}","bytes_in":${bytes_in},` +
|
||||
`"bytes_out":${bytes_out}}` + "\n",
|
||||
CustomTimeFormat:"2006-01-02 15:04:05.00000",
|
||||
Output: os.Stdout,
|
||||
colorer: color.New(),
|
||||
CustomTimeFormat: "2006-01-02 15:04:05.00000",
|
||||
Output: os.Stdout,
|
||||
colorer: color.New(),
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -2,18 +2,18 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"encoding/json"
|
||||
"github.com/labstack/echo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
@ -152,7 +152,7 @@ func TestLoggerCustomTimestamp(t *testing.T) {
|
||||
`"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` +
|
||||
`"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
|
||||
CustomTimeFormat: customTimeFormat,
|
||||
Output: buf,
|
||||
Output: buf,
|
||||
}))
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
|
@ -108,15 +108,15 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
errCh := make(chan error, 2)
|
||||
cp := func(dst io.Writer, src io.Reader) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errc <- err
|
||||
_, err = io.Copy(dst, src)
|
||||
errCh <- err
|
||||
}
|
||||
|
||||
go cp(out, in)
|
||||
go cp(in, out)
|
||||
err = <-errc
|
||||
err = <-errCh
|
||||
if err != nil && err != io.EOF {
|
||||
c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)
|
||||
}
|
||||
|
@ -4,9 +4,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -48,14 +47,25 @@ func TestProxy(t *testing.T) {
|
||||
url2, _ := url.Parse(t2.URL)
|
||||
|
||||
targets := []*ProxyTarget{
|
||||
&ProxyTarget{
|
||||
URL: url1,
|
||||
{
|
||||
Name: "target 1",
|
||||
URL: url1,
|
||||
},
|
||||
&ProxyTarget{
|
||||
URL: url2,
|
||||
{
|
||||
Name: "target 2",
|
||||
URL: url2,
|
||||
},
|
||||
}
|
||||
rb := NewRandomBalancer(targets)
|
||||
rb := NewRandomBalancer(nil)
|
||||
// must add targets:
|
||||
for _, target := range targets {
|
||||
assert.True(t, rb.AddTarget(target))
|
||||
}
|
||||
|
||||
// must ignore duplicates:
|
||||
for _, target := range targets {
|
||||
assert.False(t, rb.AddTarget(target))
|
||||
}
|
||||
|
||||
// Random
|
||||
e := echo.New()
|
||||
@ -72,6 +82,12 @@ func TestProxy(t *testing.T) {
|
||||
return expected[body]
|
||||
})
|
||||
|
||||
for _, target := range targets {
|
||||
assert.True(t, rb.RemoveTarget(target.Name))
|
||||
}
|
||||
|
||||
assert.False(t, rb.RemoveTarget("unknown target"))
|
||||
|
||||
// Round-robin
|
||||
rrb := NewRoundRobinBalancer(targets)
|
||||
e = echo.New()
|
||||
|
@ -59,8 +59,8 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
|
||||
if path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
ppath := path // Pristine path
|
||||
pnames := []string{} // Param names
|
||||
ppath := path // Pristine path
|
||||
var pnames []string // Param names
|
||||
|
||||
for i, l := 0, len(path); i < l; i++ {
|
||||
if path[i] == ':' {
|
||||
|
Loading…
Reference in New Issue
Block a user