package middleware import ( "net/http" "net/url" "strings" "testing" "github.com/labstack/echo" "github.com/labstack/echo/test" "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" ) func TestCSRF(t *testing.T) { e := echo.New() req := test.NewRequest(echo.GET, "/", nil) rec := test.NewResponseRecorder() c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ TokenLength: 16, }) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) // Generate CSRF token h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") // Without CSRF cookie req = test.NewRequest(echo.POST, "/", nil) rec = test.NewResponseRecorder() c = e.NewContext(req, rec) assert.Error(t, h(c)) // Empty/invalid CSRF token req = test.NewRequest(echo.POST, "/", nil) rec = test.NewResponseRecorder() c = e.NewContext(req, rec) req.Header().Set(echo.HeaderXCSRFToken, "") assert.Error(t, h(c)) // Valid CSRF token token := random.String(16) req.Header().Set(echo.HeaderCookie, "_csrf="+token) req.Header().Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { assert.Equal(t, http.StatusOK, rec.Status()) } } func TestCSRFTokenFromForm(t *testing.T) { f := make(url.Values) f.Set("csrf", "token") e := echo.New() req := test.NewRequest(echo.POST, "/", strings.NewReader(f.Encode())) req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) c := e.NewContext(req, nil) token, err := csrfTokenFromForm("csrf")(c) if assert.NoError(t, err) { assert.Equal(t, "token", token) } token, err = csrfTokenFromForm("invalid")(c) assert.Error(t, err) } func TestCSRFTokenFromQuery(t *testing.T) { q := make(url.Values) q.Set("csrf", "token") e := echo.New() req := test.NewRequest(echo.GET, "/?"+q.Encode(), nil) req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) c := e.NewContext(req, nil) token, err := csrfTokenFromQuery("csrf")(c) if assert.NoError(t, err) { assert.Equal(t, "token", token) } token, err = csrfTokenFromQuery("invalid")(c) assert.Error(t, err) csrfTokenFromQuery("csrf") }