1
0
mirror of https://github.com/labstack/echo.git synced 2025-12-01 22:51:17 +02:00
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana
2016-04-02 14:19:39 -07:00
parent be5148ae27
commit b5d6c05101
22 changed files with 266 additions and 300 deletions

View File

@@ -39,8 +39,8 @@ func BasicAuth(f BasicAuthFunc) echo.MiddlewareFunc {
// BasicAuthFromConfig returns an HTTP basic auth middleware from config.
// See `BasicAuth()`.
func BasicAuthFromConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth := c.Request().Header().Get(echo.Authorization)
l := len(basic)
@@ -52,7 +52,7 @@ func BasicAuthFromConfig(config BasicAuthConfig) echo.MiddlewareFunc {
if cred[i] == ':' {
// Verify credentials
if config.AuthFunc(cred[:i], cred[i+1:]) {
return next.Handle(c)
return next(c)
}
}
}
@@ -60,6 +60,6 @@ func BasicAuthFromConfig(config BasicAuthConfig) echo.MiddlewareFunc {
}
c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized
})
}
}
}

View File

@@ -21,14 +21,14 @@ func TestBasicAuth(t *testing.T) {
}
return false
}
h := BasicAuth(f)(echo.HandlerFunc(func(c echo.Context) error {
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}))
})
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
rq.Header().Set(echo.Authorization, auth)
assert.NoError(t, h.Handle(c))
assert.NoError(t, h(c))
//---------------------
// Invalid credentials
@@ -37,20 +37,20 @@ func TestBasicAuth(t *testing.T) {
// Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
rq.Header().Set(echo.Authorization, auth)
he := h.Handle(c).(*echo.HTTPError)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", rs.Header().Get(echo.WWWAuthenticate))
// Empty Authorization header
rq.Header().Set(echo.Authorization, "")
he = h.Handle(c).(*echo.HTTPError)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", rs.Header().Get(echo.WWWAuthenticate))
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
rq.Header().Set(echo.Authorization, auth)
he = h.Handle(c).(*echo.HTTPError)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", rs.Header().Get(echo.WWWAuthenticate))
}

View File

@@ -50,8 +50,8 @@ func GzipFromConfig(config GzipConfig) echo.MiddlewareFunc {
pool := gzipPool(config)
scheme := "gzip"
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
rs := c.Response()
rs.Header().Add(echo.Vary, echo.AcceptEncoding)
if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) {
@@ -74,8 +74,8 @@ func GzipFromConfig(config GzipConfig) echo.MiddlewareFunc {
rs.Header().Set(echo.ContentEncoding, scheme)
rs.SetWriter(g)
}
return next.Handle(c)
})
return next(c)
}
}
}

View File

@@ -19,11 +19,11 @@ func TestGzip(t *testing.T) {
c := echo.NewContext(rq, rec, e)
// Skip if no Accept-Encoding header
h := Gzip()(echo.HandlerFunc(func(c echo.Context) error {
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
}))
h.Handle(c)
})
h(c)
assert.Equal(t, "test", rec.Body.String())
rq = test.NewRequest(echo.GET, "/", nil)
@@ -32,7 +32,7 @@ func TestGzip(t *testing.T) {
c = echo.NewContext(rq, rec, e)
// Gzip
h.Handle(c)
h(c)
assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding))
assert.Contains(t, rec.Header().Get(echo.ContentType), echo.TextPlain)
r, err := gzip.NewReader(rec.Body)
@@ -49,10 +49,10 @@ func TestGzipNoContent(t *testing.T) {
rq := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
c := echo.NewContext(rq, rec, e)
h := Gzip()(echo.HandlerFunc(func(c echo.Context) error {
h := Gzip()(func(c echo.Context) error {
return c.NoContent(http.StatusOK)
}))
h.Handle(c)
})
h(c)
assert.Empty(t, rec.Header().Get(echo.ContentEncoding))
assert.Empty(t, rec.Header().Get(echo.ContentType))
@@ -65,9 +65,9 @@ func TestGzipNoContent(t *testing.T) {
func TestGzipErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Gzip())
e.Get("/", echo.HandlerFunc(func(c echo.Context) error {
e.Get("/", func(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, "error")
}))
})
rq := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
e.ServeHTTP(rq, rec)

View File

@@ -74,12 +74,12 @@ func LoggerFromConfig(config LoggerConfig) echo.MiddlewareFunc {
config.color.Disable()
}
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) (err error) {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
rq := c.Request()
rs := c.Response()
start := time.Now()
if err = next.Handle(c); err != nil {
if err = next(c); err != nil {
c.Error(err)
}
stop := time.Now()
@@ -129,6 +129,6 @@ func LoggerFromConfig(config LoggerConfig) echo.MiddlewareFunc {
}
})
return
})
}
}
}

View File

@@ -17,37 +17,37 @@ func TestLogger(t *testing.T) {
rq := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
c := echo.NewContext(rq, rec, e)
h := Logger()(echo.HandlerFunc(func(c echo.Context) error {
h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}))
})
// Status 2xx
h.Handle(c)
h(c)
// Status 3xx
rec = test.NewResponseRecorder()
c = echo.NewContext(rq, rec, e)
h = Logger()(echo.HandlerFunc(func(c echo.Context) error {
h = Logger()(func(c echo.Context) error {
return c.String(http.StatusTemporaryRedirect, "test")
}))
h.Handle(c)
})
h(c)
// Status 4xx
rec = test.NewResponseRecorder()
c = echo.NewContext(rq, rec, e)
h = Logger()(echo.HandlerFunc(func(c echo.Context) error {
h = Logger()(func(c echo.Context) error {
return c.String(http.StatusNotFound, "test")
}))
h.Handle(c)
})
h(c)
// Status 5xx with empty path
rq = test.NewRequest(echo.GET, "", nil)
rec = test.NewResponseRecorder()
c = echo.NewContext(rq, rec, e)
h = Logger()(echo.HandlerFunc(func(c echo.Context) error {
h = Logger()(func(c echo.Context) error {
return errors.New("error")
}))
h.Handle(c)
})
h(c)
}
func TestLoggerIPAddress(t *testing.T) {
@@ -58,24 +58,24 @@ func TestLoggerIPAddress(t *testing.T) {
buf := new(bytes.Buffer)
e.Logger().SetOutput(buf)
ip := "127.0.0.1"
h := Logger()(echo.HandlerFunc(func(c echo.Context) error {
h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}))
})
// With X-Real-IP
rq.Header().Add(echo.XRealIP, ip)
h.Handle(c)
h(c)
assert.Contains(t, ip, buf.String())
// With X-Forwarded-For
buf.Reset()
rq.Header().Del(echo.XRealIP)
rq.Header().Add(echo.XForwardedFor, ip)
h.Handle(c)
h(c)
assert.Contains(t, ip, buf.String())
// with rq.RemoteAddr
buf.Reset()
h.Handle(c)
h(c)
assert.Contains(t, ip, buf.String())
}

View File

@@ -49,8 +49,8 @@ func RecoverFromConfig(config RecoverConfig) echo.MiddlewareFunc {
config.StackSize = DefaultRecoverConfig.StackSize
}
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
defer func() {
if r := recover(); r != nil {
var err error
@@ -68,7 +68,7 @@ func RecoverFromConfig(config RecoverConfig) echo.MiddlewareFunc {
c.Error(err)
}
}()
return next.Handle(c)
})
return next(c)
}
}
}

View File

@@ -20,7 +20,7 @@ func TestRecover(t *testing.T) {
h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
panic("test")
}))
h.Handle(c)
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Status())
assert.Contains(t, buf.String(), "PANIC RECOVER")
}

View File

@@ -9,15 +9,15 @@ import (
//
// Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc {
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
url := c.Request().URL()
path := url.Path()
if path != "/" && path[len(path)-1] != '/' {
url.SetPath(path + "/")
}
return next.Handle(c)
})
return next(c)
}
}
}
@@ -26,15 +26,15 @@ func AddTrailingSlash() echo.MiddlewareFunc {
//
// Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc {
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
url := c.Request().URL()
path := url.Path()
l := len(path) - 1
if path != "/" && path[l] == '/' {
url.SetPath(path[:l])
}
return next.Handle(c)
})
return next(c)
}
}
}

View File

@@ -13,10 +13,10 @@ func TestAddTrailingSlash(t *testing.T) {
rq := test.NewRequest(echo.GET, "/add-slash", nil)
rc := test.NewResponseRecorder()
c := echo.NewContext(rq, rc, e)
h := AddTrailingSlash()(echo.HandlerFunc(func(c echo.Context) error {
h := AddTrailingSlash()(func(c echo.Context) error {
return nil
}))
h.Handle(c)
})
h(c)
assert.Equal(t, "/add-slash/", rq.URL().Path())
}
@@ -25,9 +25,9 @@ func TestRemoveTrailingSlash(t *testing.T) {
rq := test.NewRequest(echo.GET, "/remove-slash/", nil)
rc := test.NewResponseRecorder()
c := echo.NewContext(rq, rc, e)
h := RemoveTrailingSlash()(echo.HandlerFunc(func(c echo.Context) error {
h := RemoveTrailingSlash()(func(c echo.Context) error {
return nil
}))
h.Handle(c)
})
h(c)
assert.Equal(t, "/remove-slash", rq.URL().Path())
}

View File

@@ -51,8 +51,8 @@ func StaticFromConfig(config StaticConfig) echo.MiddlewareFunc {
config.Index = DefaultStaticConfig.Index
}
return func(next echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
fs := http.Dir(config.Root)
p := c.Request().URL().Path()
if c.P(0) != "" { // If serving from `Group`, e.g. `/static/*`
@@ -61,7 +61,7 @@ func StaticFromConfig(config StaticConfig) echo.MiddlewareFunc {
file := path.Clean(p)
f, err := fs.Open(file)
if err != nil {
return next.Handle(c)
return next(c)
}
defer f.Close()
@@ -108,11 +108,11 @@ func StaticFromConfig(config StaticConfig) echo.MiddlewareFunc {
_, err = fmt.Fprintf(rs, "</pre>\n")
return err
}
return next.Handle(c)
return next(c)
}
fi, _ = f.Stat() // Index file stat
}
return c.ServeContent(f, fi.Name(), fi.ModTime())
})
}
}
}