mirror of
				https://github.com/labstack/echo.git
				synced 2025-10-30 23:57:38 +02:00 
			
		
		
		
	
							
								
								
									
										5
									
								
								echo.go
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								echo.go
									
									
									
									
									
								
							| @@ -461,6 +461,11 @@ func NewHTTPError(code int, msg ...string) *HTTPError { | ||||
| 	return he | ||||
| } | ||||
|  | ||||
| // SetCode sets code. | ||||
| func (e *HTTPError) SetCode(code int) { | ||||
| 	e.code = code | ||||
| } | ||||
|  | ||||
| // Code returns code. | ||||
| func (e *HTTPError) Code() int { | ||||
| 	return e.code | ||||
|   | ||||
| @@ -1,15 +1,15 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"github.com/labstack/echo" | ||||
| 	"io" | ||||
|  | ||||
| 	"github.com/labstack/echo" | ||||
| 	mw "github.com/labstack/echo/middleware" | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| 	e := echo.New() | ||||
| 	e.Use(mw.Logger()) | ||||
| 	e.Use(mw.Gzip()) | ||||
| 	e.WebSocket("/ws", func(c *echo.Context) error { | ||||
| 		io.Copy(c.Socket(), c.Socket()) | ||||
| 		return nil | ||||
|   | ||||
| @@ -2,59 +2,41 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/labstack/echo" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type ( | ||||
| 	AuthFunc func(string, string) bool | ||||
| 	BasicValidateFunc func(string, string) bool | ||||
| 	JWTValidateFunc   func(string, jwt.SigningMethod) ([]byte, error) | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	Basic = "Basic" | ||||
| 	Basic  = "Basic" | ||||
| 	Bearer = "Bearer" | ||||
| ) | ||||
|  | ||||
| // BasicAuth returns an HTTP basic authentication middleware. For valid credentials | ||||
| // it calls the next handler in the chain. | ||||
|  | ||||
| // BasicAuth returns an HTTP basic authentication middleware. | ||||
| // | ||||
| // For valid credentials it calls the next handler. | ||||
| // For invalid Authorization header it sends "404 - Bad Request" response. | ||||
| // For invalid credentials, it sends "401 - Unauthorized" response. | ||||
| func BasicAuth(fn AuthFunc) echo.HandlerFunc { | ||||
| func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { | ||||
| 	return func(c *echo.Context) error { | ||||
| 		// Skip for WebSocket | ||||
| 		// Skip WebSocket | ||||
| 		if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		auth := c.Request().Header.Get(echo.Authorization) | ||||
| 		i := 0 | ||||
| 		code := http.StatusBadRequest | ||||
| 		l := len(Basic) | ||||
| 		he := echo.NewHTTPError(http.StatusBadRequest) | ||||
| 		println(auth) | ||||
|  | ||||
| 		for ; i < len(auth); i++ { | ||||
| 			c := auth[i] | ||||
| 			// Ignore empty spaces | ||||
| 			if c == ' ' { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// Check scheme | ||||
| 			if i < len(Basic) { | ||||
| 				// Ignore case | ||||
| 				if i == 0 { | ||||
| 					if c != Basic[i] && c != 'b' { | ||||
| 						break | ||||
| 					} | ||||
| 				} else { | ||||
| 					if c != Basic[i] { | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
| 			} else { | ||||
| 				// Extract credentials | ||||
| 				b, err := base64.StdEncoding.DecodeString(auth[i:]) | ||||
| 				if err != nil { | ||||
| 					break | ||||
| 				} | ||||
| 		if len(auth) > l+1 && auth[:l] == Basic { | ||||
| 			b, err := base64.StdEncoding.DecodeString(auth[l+1:]) | ||||
| 			if err == nil { | ||||
| 				cred := string(b) | ||||
| 				for i := 0; i < len(cred); i++ { | ||||
| 					if cred[i] == ':' { | ||||
| @@ -62,12 +44,47 @@ func BasicAuth(fn AuthFunc) echo.HandlerFunc { | ||||
| 						if fn(cred[:i], cred[i+1:]) { | ||||
| 							return nil | ||||
| 						} | ||||
| 						code = http.StatusUnauthorized | ||||
| 						break | ||||
| 						he.SetCode(http.StatusUnauthorized) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return echo.NewHTTPError(code) | ||||
| 		return he | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // JWTAuth returns a JWT authentication middleware. | ||||
| // | ||||
| // For valid token it sets JWT claims in the context with key `_claims` and calls | ||||
| // the next handler. | ||||
| // For invalid Authorization header it sends "404 - Bad Request" response. | ||||
| // For invalid credentials, it sends "401 - Unauthorized" response. | ||||
| func JWTAuth(fn JWTValidateFunc) echo.HandlerFunc { | ||||
| 	return func(c *echo.Context) error { | ||||
| 		// Skip WebSocket | ||||
| 		if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		auth := c.Request().Header.Get("Authorization") | ||||
| 		l := len(Bearer) | ||||
| 		he := echo.NewHTTPError(http.StatusBadRequest) | ||||
|  | ||||
| 		if len(auth) > l+1 && auth[:l] == Bearer { | ||||
| 			t, err := jwt.Parse(auth[l+1:], func(token *jwt.Token) (interface{}, error) { | ||||
| 				// Lookup key and verify method | ||||
| 				if kid := token.Header["kid"]; kid != nil { | ||||
| 					return fn(kid.(string), token.Method) | ||||
| 				} | ||||
| 				return fn("", token.Method) | ||||
| 			}) | ||||
| 			if err == nil && t.Valid { | ||||
| 				c.Set("_claims", t.Claims) | ||||
| 				return nil | ||||
| 			} else { | ||||
| 				he.SetCode(http.StatusUnauthorized) | ||||
| 			} | ||||
| 		} | ||||
| 		return he | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -5,9 +5,11 @@ import ( | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/labstack/echo" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"net/http/httptest" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestBasicAuth(t *testing.T) { | ||||
| @@ -22,52 +24,94 @@ func TestBasicAuth(t *testing.T) { | ||||
| 	} | ||||
| 	ba := BasicAuth(fn) | ||||
|  | ||||
| 	//------------------- | ||||
| 	// Valid credentials | ||||
| 	//------------------- | ||||
|  | ||||
| 	auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	assert.NoError(t, ba(c)) | ||||
|  | ||||
| 	// Case insensitive | ||||
| 	auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	assert.NoError(t, ba(c)) | ||||
|  | ||||
| 	//--------------------- | ||||
| 	// Invalid credentials | ||||
| 	//--------------------- | ||||
|  | ||||
| 	// Incorrect password | ||||
| 	auth = Basic + "  " + base64.StdEncoding.EncodeToString([]byte("joe:password")) | ||||
| 	auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	ba = BasicAuth(fn) | ||||
| 	he := ba(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusUnauthorized, he.Code()) | ||||
|  | ||||
| 	// Empty Authorization header | ||||
| 	req.Header.Set(echo.Authorization, "") | ||||
| 	ba = BasicAuth(fn) | ||||
| 	he = ba(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusBadRequest, he.Code()) | ||||
|  | ||||
| 	// Invalid Authorization header | ||||
| 	auth = base64.StdEncoding.EncodeToString([]byte(" :secret")) | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	ba = BasicAuth(fn) | ||||
| 	he = ba(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusBadRequest, he.Code()) | ||||
|  | ||||
| 	// Invalid scheme | ||||
| 	auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret")) | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	ba = BasicAuth(fn) | ||||
| 	he = ba(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusBadRequest, he.Code()) | ||||
|  | ||||
| 	// WebSocket | ||||
| 	c.Request().Header.Set(echo.Upgrade, echo.WebSocket) | ||||
| 	ba = BasicAuth(fn) | ||||
| 	assert.NoError(t, ba(c)) | ||||
| } | ||||
|  | ||||
| func TestJWTAuth(t *testing.T) { | ||||
| 	req, _ := http.NewRequest(echo.GET, "/", nil) | ||||
| 	rec := httptest.NewRecorder() | ||||
| 	c := echo.NewContext(req, echo.NewResponse(rec), echo.New()) | ||||
| 	key := []byte("key") | ||||
| 	fn := func(kid string, method jwt.SigningMethod) ([]byte, error) { | ||||
| 		return key, nil | ||||
| 	} | ||||
| 	ja := JWTAuth(fn) | ||||
| 	token := jwt.New(jwt.SigningMethodHS256) | ||||
| 	token.Claims["foo"] = "bar" | ||||
| 	token.Claims["exp"] = time.Now().Add(time.Hour * 72).Unix() | ||||
| 	ts, err := token.SignedString(key) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	// Valid credentials | ||||
| 	auth := Bearer + " " + ts | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	assert.NoError(t, ja(c)) | ||||
|  | ||||
| 	//--------------------- | ||||
| 	// Invalid credentials | ||||
| 	//--------------------- | ||||
|  | ||||
| 	// Expired token | ||||
| 	token.Claims["exp"] = time.Now().Add(-time.Second).Unix() | ||||
| 	ts, err = token.SignedString(key) | ||||
| 	assert.NoError(t, err) | ||||
| 	auth = Bearer + " " + ts | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	he := ja(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusUnauthorized, he.Code()) | ||||
|  | ||||
| 	// Empty Authorization header | ||||
| 	req.Header.Set(echo.Authorization, "") | ||||
| 	he = ja(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusBadRequest, he.Code()) | ||||
|  | ||||
| 	// Invalid Authorization header | ||||
| 	auth = "token" | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	he = ja(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusBadRequest, he.Code()) | ||||
|  | ||||
| 	// Invalid scheme | ||||
| 	auth = "Bear token" | ||||
| 	req.Header.Set(echo.Authorization, auth) | ||||
| 	he = ja(c).(*echo.HTTPError) | ||||
| 	assert.Equal(t, http.StatusBadRequest, he.Code()) | ||||
|  | ||||
| 	// WebSocket | ||||
| 	c.Request().Header.Set(echo.Upgrade, echo.WebSocket) | ||||
| 	assert.NoError(t, ja(c)) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user