1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

[FIX] Cleanup code (#1061)

Code cleanup
This commit is contained in:
Evgeniy Kulikov 2018-02-21 21:44:17 +03:00 committed by Vishal Rana
parent 90d675fa2a
commit f49d166e6f
19 changed files with 141 additions and 65 deletions

View File

@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
val := reflect.ValueOf(ptr).Elem()
if typ.Kind() != reflect.Struct {
return errors.New("Binding element must be a struct")
return errors.New("binding element must be a struct")
}
for i := 0; i < typ.NumField(); i++ {

View File

@ -135,8 +135,7 @@ func TestBindForm(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, MIMEApplicationForm)
obj := []struct{ Field string }{}
err := c.Bind(&obj)
err := c.Bind(&[]struct{ Field string }{})
assert.Error(t, err)
}

View File

@ -2,21 +2,18 @@ package echo
import (
"bytes"
"encoding/xml"
"errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"text/template"
"time"
"strings"
"net/url"
"encoding/xml"
"github.com/stretchr/testify/assert"
)
@ -217,7 +214,7 @@ func TestContext(t *testing.T) {
c.SetParamNames("foo")
c.SetParamValues("bar")
c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": []string{"baz"}})
c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder())
assert.Equal(t, 0, len(c.ParamValues()))
assert.Equal(t, 0, len(c.ParamNames()))

18
echo.go
View File

@ -251,10 +251,10 @@ var (
ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrValidatorNotRegistered = errors.New("Validator not registered")
ErrRendererNotRegistered = errors.New("Renderer not registered")
ErrInvalidRedirectCode = errors.New("Invalid redirect status code")
ErrCookieNotFound = errors.New("Cookie not found")
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
)
// Error handlers
@ -530,7 +530,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
// Routes returns the registered routes.
func (e *Echo) Routes() []*Route {
routes := []*Route{}
routes := make([]*Route, 0, len(e.router.routes))
for _, v := range e.router.routes {
routes = append(routes, v)
}
@ -563,11 +563,11 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Middleware
h := func(c Context) error {
method := r.Method
path := r.URL.RawPath
if path == "" {
path = r.URL.Path
rpath := r.URL.RawPath // raw path
if rpath == "" {
rpath = r.URL.Path
}
e.router.Find(method, path, c)
e.router.Find(method, rpath, c)
h := c.Handler()
for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h)

View File

@ -2,15 +2,12 @@ package echo
import (
"bytes"
"errors"
"net/http"
"net/http/httptest"
"testing"
"reflect"
"strings"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"

View File

@ -92,7 +92,7 @@ func (g *Group) Match(methods []string, path string, handler HandlerFunc, middle
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
m := []MiddlewareFunc{}
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
return g.echo.Group(g.prefix+prefix, m...)
@ -113,7 +113,7 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd
// Combine into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
m := []MiddlewareFunc{}
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
return g.echo.Add(method, g.prefix+path, handler, m...)

View File

@ -93,10 +93,8 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
}
}
realm := ""
if config.Realm == defaultRealm {
realm = defaultRealm
} else {
realm := defaultRealm
if config.Realm != defaultRealm {
realm = strconv.Quote(config.Realm)
}

View File

@ -31,6 +31,19 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
@ -41,7 +54,7 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)

View File

@ -3,12 +3,11 @@ package middleware
import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net"
"net/http"
"io"
"github.com/labstack/echo"
)

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -31,10 +32,65 @@ func TestBodyDump(t *testing.T) {
requestBody = string(reqBody)
responseBody = string(resBody)
})
if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String())
}
// Must set default skipper
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})
}
func TestBodyDumpFails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(echo.POST, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
}
requestBody := ""
responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
})
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
})
assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
})
}

View File

@ -36,8 +36,8 @@ func TestGzip(t *testing.T) {
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body)
if assert.NoError(t, err) {
defer r.Close()
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal(t, "test", buf.String())
}

View File

@ -124,10 +124,11 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
req := c.Request()
k, err := c.Cookie(config.CookieName)
token := ""
var token string
// Generate token
if err != nil {
// Generate token
token = random.String(config.TokenLength)
} else {
// Reuse token
@ -187,7 +188,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("Missing csrf token in the form parameter")
return "", errors.New("missing csrf token in the form parameter")
}
return token, nil
}
@ -199,7 +200,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("Missing csrf token in the query string")
return "", errors.New("missing csrf token in the query string")
}
return token, nil
}

View File

@ -116,7 +116,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return config.SigningKey, nil
}

View File

@ -114,14 +114,14 @@ func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("Missing key in request header")
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("Invalid key in the request header")
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
@ -132,7 +132,7 @@ func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("Missing key in the query string")
return "", errors.New("missing key in the query string")
}
return key, nil
}
@ -143,7 +143,7 @@ func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("Missing key in the form")
return "", errors.New("missing key in the form")
}
return key, nil
}

View File

@ -47,7 +47,7 @@ type (
// Example "${remote_ip} ${status}"
//
// Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"`
Format string `yaml:"format"`
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"`
@ -70,9 +70,9 @@ var (
`"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
`"latency_human":"${latency_human}","bytes_in":${bytes_in},` +
`"bytes_out":${bytes_out}}` + "\n",
CustomTimeFormat:"2006-01-02 15:04:05.00000",
Output: os.Stdout,
colorer: color.New(),
CustomTimeFormat: "2006-01-02 15:04:05.00000",
Output: os.Stdout,
colorer: color.New(),
}
)

View File

@ -2,18 +2,18 @@ package middleware
import (
"bytes"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"encoding/json"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
"time"
"unsafe"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
)
func TestLogger(t *testing.T) {
@ -152,7 +152,7 @@ func TestLoggerCustomTimestamp(t *testing.T) {
`"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` +
`"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
CustomTimeFormat: customTimeFormat,
Output: buf,
Output: buf,
}))
e.GET("/", func(c echo.Context) error {

View File

@ -108,15 +108,15 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return
}
errc := make(chan error, 2)
errCh := make(chan error, 2)
cp := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errc <- err
_, err = io.Copy(dst, src)
errCh <- err
}
go cp(out, in)
go cp(in, out)
err = <-errc
err = <-errCh
if err != nil && err != io.EOF {
c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)
}

View File

@ -4,9 +4,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"net/url"
"testing"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
@ -48,14 +47,25 @@ func TestProxy(t *testing.T) {
url2, _ := url.Parse(t2.URL)
targets := []*ProxyTarget{
&ProxyTarget{
URL: url1,
{
Name: "target 1",
URL: url1,
},
&ProxyTarget{
URL: url2,
{
Name: "target 2",
URL: url2,
},
}
rb := NewRandomBalancer(targets)
rb := NewRandomBalancer(nil)
// must add targets:
for _, target := range targets {
assert.True(t, rb.AddTarget(target))
}
// must ignore duplicates:
for _, target := range targets {
assert.False(t, rb.AddTarget(target))
}
// Random
e := echo.New()
@ -72,6 +82,12 @@ func TestProxy(t *testing.T) {
return expected[body]
})
for _, target := range targets {
assert.True(t, rb.RemoveTarget(target.Name))
}
assert.False(t, rb.RemoveTarget("unknown target"))
// Round-robin
rrb := NewRoundRobinBalancer(targets)
e = echo.New()

View File

@ -59,8 +59,8 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
if path[0] != '/' {
path = "/" + path
}
ppath := path // Pristine path
pnames := []string{} // Param names
ppath := path // Pristine path
var pnames []string // Param names
for i, l := 0, len(path); i < l; i++ {
if path[i] == ':' {