mirror of
https://github.com/labstack/echo.git
synced 2025-05-13 22:06:36 +02:00
parent
90d675fa2a
commit
f49d166e6f
2
bind.go
2
bind.go
@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
|
|||||||
val := reflect.ValueOf(ptr).Elem()
|
val := reflect.ValueOf(ptr).Elem()
|
||||||
|
|
||||||
if typ.Kind() != reflect.Struct {
|
if typ.Kind() != reflect.Struct {
|
||||||
return errors.New("Binding element must be a struct")
|
return errors.New("binding element must be a struct")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
@ -135,8 +135,7 @@ func TestBindForm(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
req.Header.Set(HeaderContentType, MIMEApplicationForm)
|
req.Header.Set(HeaderContentType, MIMEApplicationForm)
|
||||||
obj := []struct{ Field string }{}
|
err := c.Bind(&[]struct{ Field string }{})
|
||||||
err := c.Bind(&obj)
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,21 +2,18 @@ package echo
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/xml"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"encoding/xml"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -217,7 +214,7 @@ func TestContext(t *testing.T) {
|
|||||||
c.SetParamNames("foo")
|
c.SetParamNames("foo")
|
||||||
c.SetParamValues("bar")
|
c.SetParamValues("bar")
|
||||||
c.Set("foe", "ban")
|
c.Set("foe", "ban")
|
||||||
c.query = url.Values(map[string][]string{"fon": []string{"baz"}})
|
c.query = url.Values(map[string][]string{"fon": {"baz"}})
|
||||||
c.Reset(req, httptest.NewRecorder())
|
c.Reset(req, httptest.NewRecorder())
|
||||||
assert.Equal(t, 0, len(c.ParamValues()))
|
assert.Equal(t, 0, len(c.ParamValues()))
|
||||||
assert.Equal(t, 0, len(c.ParamNames()))
|
assert.Equal(t, 0, len(c.ParamNames()))
|
||||||
|
18
echo.go
18
echo.go
@ -251,10 +251,10 @@ var (
|
|||||||
ErrForbidden = NewHTTPError(http.StatusForbidden)
|
ErrForbidden = NewHTTPError(http.StatusForbidden)
|
||||||
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
|
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
|
||||||
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
|
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
|
||||||
ErrValidatorNotRegistered = errors.New("Validator not registered")
|
ErrValidatorNotRegistered = errors.New("validator not registered")
|
||||||
ErrRendererNotRegistered = errors.New("Renderer not registered")
|
ErrRendererNotRegistered = errors.New("renderer not registered")
|
||||||
ErrInvalidRedirectCode = errors.New("Invalid redirect status code")
|
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
|
||||||
ErrCookieNotFound = errors.New("Cookie not found")
|
ErrCookieNotFound = errors.New("cookie not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error handlers
|
// Error handlers
|
||||||
@ -530,7 +530,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string {
|
|||||||
|
|
||||||
// Routes returns the registered routes.
|
// Routes returns the registered routes.
|
||||||
func (e *Echo) Routes() []*Route {
|
func (e *Echo) Routes() []*Route {
|
||||||
routes := []*Route{}
|
routes := make([]*Route, 0, len(e.router.routes))
|
||||||
for _, v := range e.router.routes {
|
for _, v := range e.router.routes {
|
||||||
routes = append(routes, v)
|
routes = append(routes, v)
|
||||||
}
|
}
|
||||||
@ -563,11 +563,11 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Middleware
|
// Middleware
|
||||||
h := func(c Context) error {
|
h := func(c Context) error {
|
||||||
method := r.Method
|
method := r.Method
|
||||||
path := r.URL.RawPath
|
rpath := r.URL.RawPath // raw path
|
||||||
if path == "" {
|
if rpath == "" {
|
||||||
path = r.URL.Path
|
rpath = r.URL.Path
|
||||||
}
|
}
|
||||||
e.router.Find(method, path, c)
|
e.router.Find(method, rpath, c)
|
||||||
h := c.Handler()
|
h := c.Handler()
|
||||||
for i := len(e.middleware) - 1; i >= 0; i-- {
|
for i := len(e.middleware) - 1; i >= 0; i-- {
|
||||||
h = e.middleware[i](h)
|
h = e.middleware[i](h)
|
||||||
|
@ -2,15 +2,12 @@ package echo
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"testing"
|
||||||
"errors"
|
|
||||||
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
4
group.go
4
group.go
@ -92,7 +92,7 @@ func (g *Group) Match(methods []string, path string, handler HandlerFunc, middle
|
|||||||
|
|
||||||
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
|
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
|
||||||
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
|
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) *Group {
|
||||||
m := []MiddlewareFunc{}
|
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
|
||||||
m = append(m, g.middleware...)
|
m = append(m, g.middleware...)
|
||||||
m = append(m, middleware...)
|
m = append(m, middleware...)
|
||||||
return g.echo.Group(g.prefix+prefix, m...)
|
return g.echo.Group(g.prefix+prefix, m...)
|
||||||
@ -113,7 +113,7 @@ func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...Midd
|
|||||||
// Combine into a new slice to avoid accidentally passing the same slice for
|
// Combine into a new slice to avoid accidentally passing the same slice for
|
||||||
// multiple routes, which would lead to later add() calls overwriting the
|
// multiple routes, which would lead to later add() calls overwriting the
|
||||||
// middleware from earlier calls.
|
// middleware from earlier calls.
|
||||||
m := []MiddlewareFunc{}
|
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
|
||||||
m = append(m, g.middleware...)
|
m = append(m, g.middleware...)
|
||||||
m = append(m, middleware...)
|
m = append(m, middleware...)
|
||||||
return g.echo.Add(method, g.prefix+path, handler, m...)
|
return g.echo.Add(method, g.prefix+path, handler, m...)
|
||||||
|
@ -93,10 +93,8 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
realm := ""
|
realm := defaultRealm
|
||||||
if config.Realm == defaultRealm {
|
if config.Realm != defaultRealm {
|
||||||
realm = defaultRealm
|
|
||||||
} else {
|
|
||||||
realm = strconv.Quote(config.Realm)
|
realm = strconv.Quote(config.Realm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,19 @@ func TestBasicAuth(t *testing.T) {
|
|||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||||
assert.NoError(t, h(c))
|
assert.NoError(t, h(c))
|
||||||
|
|
||||||
|
h = BasicAuthWithConfig(BasicAuthConfig{
|
||||||
|
Skipper: nil,
|
||||||
|
Validator: f,
|
||||||
|
Realm: "someRealm",
|
||||||
|
})(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Valid credentials
|
||||||
|
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||||
|
assert.NoError(t, h(c))
|
||||||
|
|
||||||
// Case-insensitive header scheme
|
// Case-insensitive header scheme
|
||||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||||
@ -41,7 +54,7 @@ func TestBasicAuth(t *testing.T) {
|
|||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||||
he := h(c).(*echo.HTTPError)
|
he := h(c).(*echo.HTTPError)
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||||
|
|
||||||
// Missing Authorization header
|
// Missing Authorization header
|
||||||
req.Header.Del(echo.HeaderAuthorization)
|
req.Header.Del(echo.HeaderAuthorization)
|
||||||
|
@ -3,12 +3,11 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -31,10 +32,65 @@ func TestBodyDump(t *testing.T) {
|
|||||||
requestBody = string(reqBody)
|
requestBody = string(reqBody)
|
||||||
responseBody = string(resBody)
|
responseBody = string(resBody)
|
||||||
})
|
})
|
||||||
|
|
||||||
if assert.NoError(t, mw(h)(c)) {
|
if assert.NoError(t, mw(h)(c)) {
|
||||||
assert.Equal(t, requestBody, hw)
|
assert.Equal(t, requestBody, hw)
|
||||||
assert.Equal(t, responseBody, hw)
|
assert.Equal(t, responseBody, hw)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, hw, rec.Body.String())
|
assert.Equal(t, hw, rec.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must set default skipper
|
||||||
|
BodyDumpWithConfig(BodyDumpConfig{
|
||||||
|
Skipper: nil,
|
||||||
|
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||||
|
requestBody = string(reqBody)
|
||||||
|
responseBody = string(resBody)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBodyDumpFails(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
hw := "Hello, World!"
|
||||||
|
req := httptest.NewRequest(echo.POST, "/", strings.NewReader(hw))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
h := func(c echo.Context) error {
|
||||||
|
return errors.New("some error")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestBody := ""
|
||||||
|
responseBody := ""
|
||||||
|
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
|
||||||
|
requestBody = string(reqBody)
|
||||||
|
responseBody = string(resBody)
|
||||||
|
})
|
||||||
|
|
||||||
|
if !assert.Error(t, mw(h)(c)) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
mw = BodyDumpWithConfig(BodyDumpConfig{
|
||||||
|
Skipper: nil,
|
||||||
|
Handler: nil,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
mw = BodyDumpWithConfig(BodyDumpConfig{
|
||||||
|
Skipper: func(c echo.Context) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
Handler: func(c echo.Context, reqBody, resBody []byte) {
|
||||||
|
requestBody = string(reqBody)
|
||||||
|
responseBody = string(resBody)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if !assert.Error(t, mw(h)(c)) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -36,8 +36,8 @@ func TestGzip(t *testing.T) {
|
|||||||
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
|
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
|
||||||
r, err := gzip.NewReader(rec.Body)
|
r, err := gzip.NewReader(rec.Body)
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
defer r.Close()
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
defer r.Close()
|
||||||
buf.ReadFrom(r)
|
buf.ReadFrom(r)
|
||||||
assert.Equal(t, "test", buf.String())
|
assert.Equal(t, "test", buf.String())
|
||||||
}
|
}
|
||||||
|
@ -124,10 +124,11 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
|
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
k, err := c.Cookie(config.CookieName)
|
k, err := c.Cookie(config.CookieName)
|
||||||
token := ""
|
|
||||||
|
|
||||||
if err != nil {
|
var token string
|
||||||
|
|
||||||
// Generate token
|
// Generate token
|
||||||
|
if err != nil {
|
||||||
token = random.String(config.TokenLength)
|
token = random.String(config.TokenLength)
|
||||||
} else {
|
} else {
|
||||||
// Reuse token
|
// Reuse token
|
||||||
@ -187,7 +188,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor {
|
|||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
token := c.FormValue(param)
|
token := c.FormValue(param)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return "", errors.New("Missing csrf token in the form parameter")
|
return "", errors.New("missing csrf token in the form parameter")
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@ -199,7 +200,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
|||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
token := c.QueryParam(param)
|
token := c.QueryParam(param)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return "", errors.New("Missing csrf token in the query string")
|
return "", errors.New("missing csrf token in the query string")
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
@ -116,7 +116,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
|
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
|
||||||
// Check the signing method
|
// Check the signing method
|
||||||
if t.Method.Alg() != config.SigningMethod {
|
if t.Method.Alg() != config.SigningMethod {
|
||||||
return nil, fmt.Errorf("Unexpected jwt signing method=%v", t.Header["alg"])
|
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||||
}
|
}
|
||||||
return config.SigningKey, nil
|
return config.SigningKey, nil
|
||||||
}
|
}
|
||||||
|
@ -114,14 +114,14 @@ func keyFromHeader(header string, authScheme string) keyExtractor {
|
|||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
auth := c.Request().Header.Get(header)
|
auth := c.Request().Header.Get(header)
|
||||||
if auth == "" {
|
if auth == "" {
|
||||||
return "", errors.New("Missing key in request header")
|
return "", errors.New("missing key in request header")
|
||||||
}
|
}
|
||||||
if header == echo.HeaderAuthorization {
|
if header == echo.HeaderAuthorization {
|
||||||
l := len(authScheme)
|
l := len(authScheme)
|
||||||
if len(auth) > l+1 && auth[:l] == authScheme {
|
if len(auth) > l+1 && auth[:l] == authScheme {
|
||||||
return auth[l+1:], nil
|
return auth[l+1:], nil
|
||||||
}
|
}
|
||||||
return "", errors.New("Invalid key in the request header")
|
return "", errors.New("invalid key in the request header")
|
||||||
}
|
}
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
@ -132,7 +132,7 @@ func keyFromQuery(param string) keyExtractor {
|
|||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
key := c.QueryParam(param)
|
key := c.QueryParam(param)
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return "", errors.New("Missing key in the query string")
|
return "", errors.New("missing key in the query string")
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
@ -143,7 +143,7 @@ func keyFromForm(param string) keyExtractor {
|
|||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
key := c.FormValue(param)
|
key := c.FormValue(param)
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return "", errors.New("Missing key in the form")
|
return "", errors.New("missing key in the form")
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
@ -2,18 +2,18 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/labstack/echo"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/labstack/echo"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLogger(t *testing.T) {
|
func TestLogger(t *testing.T) {
|
||||||
|
@ -108,15 +108,15 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
errc := make(chan error, 2)
|
errCh := make(chan error, 2)
|
||||||
cp := func(dst io.Writer, src io.Reader) {
|
cp := func(dst io.Writer, src io.Reader) {
|
||||||
_, err := io.Copy(dst, src)
|
_, err = io.Copy(dst, src)
|
||||||
errc <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
|
|
||||||
go cp(out, in)
|
go cp(out, in)
|
||||||
go cp(in, out)
|
go cp(in, out)
|
||||||
err = <-errc
|
err = <-errCh
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)
|
c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)
|
||||||
}
|
}
|
||||||
|
@ -4,9 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -48,14 +47,25 @@ func TestProxy(t *testing.T) {
|
|||||||
url2, _ := url.Parse(t2.URL)
|
url2, _ := url.Parse(t2.URL)
|
||||||
|
|
||||||
targets := []*ProxyTarget{
|
targets := []*ProxyTarget{
|
||||||
&ProxyTarget{
|
{
|
||||||
|
Name: "target 1",
|
||||||
URL: url1,
|
URL: url1,
|
||||||
},
|
},
|
||||||
&ProxyTarget{
|
{
|
||||||
|
Name: "target 2",
|
||||||
URL: url2,
|
URL: url2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rb := NewRandomBalancer(targets)
|
rb := NewRandomBalancer(nil)
|
||||||
|
// must add targets:
|
||||||
|
for _, target := range targets {
|
||||||
|
assert.True(t, rb.AddTarget(target))
|
||||||
|
}
|
||||||
|
|
||||||
|
// must ignore duplicates:
|
||||||
|
for _, target := range targets {
|
||||||
|
assert.False(t, rb.AddTarget(target))
|
||||||
|
}
|
||||||
|
|
||||||
// Random
|
// Random
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
@ -72,6 +82,12 @@ func TestProxy(t *testing.T) {
|
|||||||
return expected[body]
|
return expected[body]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
assert.True(t, rb.RemoveTarget(target.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, rb.RemoveTarget("unknown target"))
|
||||||
|
|
||||||
// Round-robin
|
// Round-robin
|
||||||
rrb := NewRoundRobinBalancer(targets)
|
rrb := NewRoundRobinBalancer(targets)
|
||||||
e = echo.New()
|
e = echo.New()
|
||||||
|
@ -60,7 +60,7 @@ func (r *Router) Add(method, path string, h HandlerFunc) {
|
|||||||
path = "/" + path
|
path = "/" + path
|
||||||
}
|
}
|
||||||
ppath := path // Pristine path
|
ppath := path // Pristine path
|
||||||
pnames := []string{} // Param names
|
var pnames []string // Param names
|
||||||
|
|
||||||
for i, l := 0, len(path); i < l; i++ {
|
for i, l := 0, len(path); i < l; i++ {
|
||||||
if path[i] == ':' {
|
if path[i] == ':' {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user