1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-22 20:06:21 +02:00
echo/middleware/basic_auth_test.go

159 lines
4.5 KiB
Go
Raw Permalink Normal View History

package middleware
import (
"encoding/base64"
2021-07-15 22:34:01 +02:00
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
2021-07-15 22:34:01 +02:00
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
2021-07-15 22:34:01 +02:00
validatorFunc := func(c echo.Context, u, p string) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
2021-07-15 22:34:01 +02:00
if u == "error" {
return false, errors.New(p)
}
return false, nil
}
2021-07-15 22:34:01 +02:00
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
2021-07-15 22:34:01 +02:00
var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, multiple",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
basic + " NOT_BASE64",
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3",
2021-07-15 22:34:01 +02:00
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}
2018-10-14 09:18:44 +02:00
2021-07-15 22:34:01 +02:00
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
2021-07-15 22:34:01 +02:00
config := tc.givenConfig
2021-07-15 22:34:01 +02:00
mw, err := config.ToMiddleware()
assert.NoError(t, err)
2021-07-15 22:34:01 +02:00
h := mw(func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
2021-07-15 22:34:01 +02:00
if len(tc.whenAuth) != 0 {
for _, a := range tc.whenAuth {
req.Header.Add(echo.HeaderAuthorization, a)
}
}
err = h(c)
if tc.expectErr != "" {
assert.Equal(t, http.StatusOK, res.Code)
assert.EqualError(t, err, tc.expectErr)
} else {
assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
}
if tc.expectHeader != "" {
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
}
})
}
}
2021-07-15 22:34:01 +02:00
func TestBasicAuth_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuth(nil)
assert.NotNil(t, mw)
})
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
return true, nil
})
assert.NotNil(t, mw)
}
func TestBasicAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
assert.NotNil(t, mw)
})
2021-07-15 22:34:01 +02:00
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
return true, nil
}})
assert.NotNil(t, mw)
}