1
0
mirror of https://github.com/labstack/echo.git synced 2025-05-13 22:06:36 +02:00

[FIX] Cleanup code (#1061)

Code cleanup
This commit is contained in:
Evgeniy Kulikov 2018-02-21 21:44:17 +03:00 committed by Vishal Rana
parent 90d675fa2a
commit f49d166e6f
19 changed files with 141 additions and 65 deletions

View File

@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
val := reflect.ValueOf(ptr).Elem() val := reflect.ValueOf(ptr).Elem()
if typ.Kind() != reflect.Struct { if typ.Kind() != reflect.Struct {
return errors.New("Binding element must be a struct") return errors.New("binding element must be a struct")
} }
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {

View File

@ -135,8 +135,7 @@ func TestBindForm(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, MIMEApplicationForm) req.Header.Set(HeaderContentType, MIMEApplicationForm)
obj := []struct{ Field string }{} err := c.Bind(&[]struct{ Field string }{})
err := c.Bind(&obj)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -2,21 +2,18 @@ package echo
import ( import (
"bytes" "bytes"
"encoding/xml"
"errors" "errors"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings"
"testing" "testing"
"text/template" "text/template"
"time" "time"
"strings"
"net/url"
"encoding/xml"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -217,7 +214,7 @@ func TestContext(t *testing.T) {
c.SetParamNames("foo") c.SetParamNames("foo")
c.SetParamValues("bar") c.SetParamValues("bar")
c.Set("foe", "ban") c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": []string{"baz"}}) c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder()) c.Reset(req, httptest.NewRecorder())
assert.Equal(t, 0, len(c.ParamValues())) assert.Equal(t, 0, len(c.ParamValues()))
assert.Equal(t, 0, len(c.ParamNames())) assert.Equal(t, 0, len(c.ParamNames()))

18
echo.go
View File

@ -251,10 +251,10 @@ var (
ErrForbidden = NewHTTPError(http.StatusForbidden) ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrValidatorNotRegistered = errors.New("Validator not registered") ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("Renderer not registered") ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("Invalid redirect status code") ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("Cookie not found") ErrCookieNotFound = errors.New("cookie not found")
) )
// Error handlers // Error handlers
@ -530,7 +530,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
// Routes returns the registered routes. // Routes returns the registered routes.
func (e *Echo) Routes() []*Route { func (e *Echo) Routes() []*Route {
routes := []*Route{} routes := make([]*Route, 0, len(e.router.routes))
for _, v := range e.router.routes { for _, v := range e.router.routes {
routes = append(routes, v) routes = append(routes, v)
} }
@ -563,11 +563,11 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Middleware // Middleware
h := func(c Context) error { h := func(c Context) error {
method := r.Method method := r.Method
path := r.URL.RawPath rpath := r.URL.RawPath // raw path
if path == "" { if rpath == "" {
path = r.URL.Path rpath = r.URL.Path
} }
e.router.Find(method, path, c) e.router.Find(method, rpath, c)
h := c.Handler() h := c.Handler()
for i := len(e.middleware) - 1; i >= 0; i-- { for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h) h = e.middleware[i](h)

View File

@ -2,15 +2,12 @@ package echo
import ( import (
"bytes" "bytes"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing"
"reflect" "reflect"
"strings" "strings"
"testing"
"errors"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"

View File

@ -92,7 +92,7 @@ func (g *Group) Match(methods []string, path string, handler HandlerFunc, middle
// Group creates a new sub-group with prefix and optional sub-group-level middleware. // Group creates a new sub-group with prefix and optional sub-group-level middleware.
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group { func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
m := []MiddlewareFunc{} m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...) m = append(m, g.middleware...)
m = append(m, middleware...) m = append(m, middleware...)
return g.echo.Group(g.prefix+prefix, m...) return g.echo.Group(g.prefix+prefix, m...)
@ -113,7 +113,7 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd
// Combine into a new slice to avoid accidentally passing the same slice for // Combine into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the // multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls. // middleware from earlier calls.
m := []MiddlewareFunc{} m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...) m = append(m, g.middleware...)
m = append(m, middleware...) m = append(m, middleware...)
return g.echo.Add(method, g.prefix+path, handler, m...) return g.echo.Add(method, g.prefix+path, handler, m...)

View File

@ -93,10 +93,8 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
} }
} }
realm := "" realm := defaultRealm
if config.Realm == defaultRealm { if config.Realm != defaultRealm {
realm = defaultRealm
} else {
realm = strconv.Quote(config.Realm) realm = strconv.Quote(config.Realm)
} }

View File

@ -31,6 +31,19 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth) req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c)) assert.NoError(t, h(c))
h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Case-insensitive header scheme // Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth) req.Header.Set(echo.HeaderAuthorization, auth)
@ -41,7 +54,7 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth) req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError) he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Missing Authorization header // Missing Authorization header
req.Header.Del(echo.HeaderAuthorization) req.Header.Del(echo.HeaderAuthorization)

View File

@ -3,12 +3,11 @@ package middleware
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"io"
"github.com/labstack/echo" "github.com/labstack/echo"
) )

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -31,10 +32,65 @@ func TestBodyDump(t *testing.T) {
requestBody = string(reqBody) requestBody = string(reqBody)
responseBody = string(resBody) responseBody = string(resBody)
}) })
if assert.NoError(t, mw(h)(c)) { if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw) assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw) assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String()) assert.Equal(t, hw, rec.Body.String())
} }
// Must set default skipper
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})
}
func TestBodyDumpFails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(echo.POST, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
}
requestBody := ""
responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
})
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
})
assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
})
} }

View File

@ -36,8 +36,8 @@ func TestGzip(t *testing.T) {
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body) r, err := gzip.NewReader(rec.Body)
if assert.NoError(t, err) { if assert.NoError(t, err) {
defer r.Close()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r) buf.ReadFrom(r)
assert.Equal(t, "test", buf.String()) assert.Equal(t, "test", buf.String())
} }

View File

@ -124,10 +124,11 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
k, err := c.Cookie(config.CookieName) k, err := c.Cookie(config.CookieName)
token := ""
if err != nil { var token string
// Generate token // Generate token
if err != nil {
token = random.String(config.TokenLength) token = random.String(config.TokenLength)
} else { } else {
// Reuse token // Reuse token
@ -187,7 +188,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.FormValue(param) token := c.FormValue(param)
if token == "" { if token == "" {
return "", errors.New("Missing csrf token in the form parameter") return "", errors.New("missing csrf token in the form parameter")
} }
return token, nil return token, nil
} }
@ -199,7 +200,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.QueryParam(param) token := c.QueryParam(param)
if token == "" { if token == "" {
return "", errors.New("Missing csrf token in the query string") return "", errors.New("missing csrf token in the query string")
} }
return token, nil return token, nil
} }

View File

@ -116,7 +116,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
config.keyFunc = func(t *jwt.Token) (interface{}, error) { config.keyFunc = func(t *jwt.Token) (interface{}, error) {
// Check the signing method // Check the signing method
if t.Method.Alg() != config.SigningMethod { if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"]) return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
} }
return config.SigningKey, nil return config.SigningKey, nil
} }

View File

@ -114,14 +114,14 @@ func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header) auth := c.Request().Header.Get(header)
if auth == "" { if auth == "" {
return "", errors.New("Missing key in request header") return "", errors.New("missing key in request header")
} }
if header == echo.HeaderAuthorization { if header == echo.HeaderAuthorization {
l := len(authScheme) l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme { if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil return auth[l+1:], nil
} }
return "", errors.New("Invalid key in the request header") return "", errors.New("invalid key in the request header")
} }
return auth, nil return auth, nil
} }
@ -132,7 +132,7 @@ func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
key := c.QueryParam(param) key := c.QueryParam(param)
if key == "" { if key == "" {
return "", errors.New("Missing key in the query string") return "", errors.New("missing key in the query string")
} }
return key, nil return key, nil
} }
@ -143,7 +143,7 @@ func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
key := c.FormValue(param) key := c.FormValue(param)
if key == "" { if key == "" {
return "", errors.New("Missing key in the form") return "", errors.New("missing key in the form")
} }
return key, nil return key, nil
} }

View File

@ -2,18 +2,18 @@ package middleware
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"encoding/json"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
"time" "time"
"unsafe" "unsafe"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
) )
func TestLogger(t *testing.T) { func TestLogger(t *testing.T) {

View File

@ -108,15 +108,15 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return return
} }
errc := make(chan error, 2) errCh := make(chan error, 2)
cp := func(dst io.Writer, src io.Reader) { cp := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src) _, err = io.Copy(dst, src)
errc <- err errCh <- err
} }
go cp(out, in) go cp(out, in)
go cp(in, out) go cp(in, out)
err = <-errc err = <-errCh
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err) c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)
} }

View File

@ -4,9 +4,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing"
"net/url" "net/url"
"testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -48,14 +47,25 @@ func TestProxy(t *testing.T) {
url2, _ := url.Parse(t2.URL) url2, _ := url.Parse(t2.URL)
targets := []*ProxyTarget{ targets := []*ProxyTarget{
&ProxyTarget{ {
Name: "target 1",
URL: url1, URL: url1,
}, },
&ProxyTarget{ {
Name: "target 2",
URL: url2, URL: url2,
}, },
} }
rb := NewRandomBalancer(targets) rb := NewRandomBalancer(nil)
// must add targets:
for _, target := range targets {
assert.True(t, rb.AddTarget(target))
}
// must ignore duplicates:
for _, target := range targets {
assert.False(t, rb.AddTarget(target))
}
// Random // Random
e := echo.New() e := echo.New()
@ -72,6 +82,12 @@ func TestProxy(t *testing.T) {
return expected[body] return expected[body]
}) })
for _, target := range targets {
assert.True(t, rb.RemoveTarget(target.Name))
}
assert.False(t, rb.RemoveTarget("unknown target"))
// Round-robin // Round-robin
rrb := NewRoundRobinBalancer(targets) rrb := NewRoundRobinBalancer(targets)
e = echo.New() e = echo.New()

View File

@ -60,7 +60,7 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
path = "/" + path path = "/" + path
} }
ppath := path // Pristine path ppath := path // Pristine path
pnames := []string{} // Param names var pnames []string // Param names
for i, l := 0, len(path); i < l; i++ { for i, l := 0, len(path); i < l; i++ {
if path[i] == ':' { if path[i] == ':' {