mirror of
https://github.com/labstack/echo.git
synced 2025-07-15 01:34:53 +02:00
V5.0.0-alpha
This commit is contained in:
@ -2,70 +2,157 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
f := func(u, p string, c echo.Context) (bool, error) {
|
||||
validatorFunc := func(c echo.Context, u, p string) (bool, error) {
|
||||
if u == "joe" && p == "secret" {
|
||||
return true, nil
|
||||
}
|
||||
if u == "error" {
|
||||
return false, errors.New(p)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
h := BasicAuth(f)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
|
||||
|
||||
assert := assert.New(t)
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenConfig BasicAuthConfig
|
||||
whenAuth []string
|
||||
expectHeader string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{
|
||||
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
|
||||
basic + " NOT_BASE64",
|
||||
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
|
||||
expectHeader: basic + ` realm=Restricted`,
|
||||
expectErr: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "nok, not base64 Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
|
||||
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
|
||||
},
|
||||
{
|
||||
name: "nok, missing Authorization header",
|
||||
givenConfig: defaultConfig,
|
||||
expectHeader: basic + ` realm=Restricted`,
|
||||
expectErr: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "ok, realm",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
|
||||
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
|
||||
},
|
||||
{
|
||||
name: "ok, realm, case-insensitive header scheme",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
|
||||
},
|
||||
{
|
||||
name: "nok, realm, invalid Authorization header",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
|
||||
expectHeader: basic + ` realm="someRealm"`,
|
||||
expectErr: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "nok, validator func returns an error",
|
||||
givenConfig: defaultConfig,
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
|
||||
expectErr: "my_error",
|
||||
},
|
||||
{
|
||||
name: "ok, skipped",
|
||||
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
}},
|
||||
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
|
||||
},
|
||||
}
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
||||
Skipper: nil,
|
||||
Validator: f,
|
||||
Realm: "someRealm",
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
config := tc.givenConfig
|
||||
|
||||
// Valid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
mw, err := config.ToMiddleware()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Case-insensitive header scheme
|
||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
h := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusTeapot, "test")
|
||||
})
|
||||
|
||||
// Invalid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
if len(tc.whenAuth) != 0 {
|
||||
for _, a := range tc.whenAuth {
|
||||
req.Header.Add(echo.HeaderAuthorization, a)
|
||||
}
|
||||
}
|
||||
err = h(c)
|
||||
|
||||
// Missing Authorization header
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
if tc.expectErr != "" {
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.EqualError(t, err, tc.expectErr)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tc.expectHeader != "" {
|
||||
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuth_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mw := BasicAuth(nil)
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
|
||||
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
assert.NotNil(t, mw)
|
||||
}
|
||||
|
||||
func TestBasicAuthWithConfig_panic(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
|
||||
assert.NotNil(t, mw)
|
||||
})
|
||||
|
||||
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
|
||||
return true, nil
|
||||
}})
|
||||
assert.NotNil(t, mw)
|
||||
}
|
||||
|
Reference in New Issue
Block a user