1
0
mirror of https://github.com/imgproxy/imgproxy.git synced 2026-04-23 19:41:06 +02:00
Files
imgproxy/server/server_test.go
2026-03-23 22:35:07 +06:00

282 lines
6.8 KiB
Go

package server_test
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/imgproxy/imgproxy/v3/errctx"
"github.com/imgproxy/imgproxy/v3/errorreport"
"github.com/imgproxy/imgproxy/v3/httpheaders"
"github.com/imgproxy/imgproxy/v3/monitoring"
"github.com/imgproxy/imgproxy/v3/server"
"github.com/imgproxy/imgproxy/v3/testutil"
)
// ServerTestSuite groups the tests for the server package. It embeds
// testutil.LazySuite and keeps two lazily constructed fixtures that are
// registered in SetupSuite.
type ServerTestSuite struct {
testutil.LazySuite
// config returns the server configuration used to build the router.
// Tests mutate the returned value before the router is first requested.
config testutil.LazyObj[*server.Config]
// router returns the router under test, built from config.
router testutil.LazyObj[*server.Router]
}
// SetupSuite registers the lazy fixtures of the suite: a server config bound
// to a loopback address and a router assembled from that config together with
// default monitoring and error-reporting instances.
//
// Only the accessor returned by NewLazySuiteObj is stored; its second return
// value is intentionally discarded.
func (s *ServerTestSuite) SetupSuite() {
	s.config, _ = testutil.NewLazySuiteObj(s, func() (*server.Config, error) {
		cfg := server.NewDefaultConfig()
		// Port 0 asks the OS to pick a free port
		cfg.Bind = "127.0.0.1:0"

		return &cfg, nil
	})

	s.router, _ = testutil.NewLazySuiteObj(s, func() (*server.Router, error) {
		monCfg := monitoring.NewDefaultConfig()
		mon, monErr := monitoring.New(s.T().Context(), &monCfg, 1)
		if monErr != nil {
			return nil, monErr
		}

		reportCfg := errorreport.NewDefaultConfig()
		reporter, reportErr := errorreport.New(&reportCfg)
		if reportErr != nil {
			return nil, reportErr
		}

		// The router is built from whatever s.config() yields at this moment,
		// so tests must adjust the config before first touching the router.
		return server.NewRouter(s.config(), mon, reporter)
	})
}
// SetupSubTest is testify's per-subtest hook. It resets the suite's lazy
// objects, presumably so every s.Run case builds its own config and router
// instead of inheriting the ones mutated by the previous case.
func (s *ServerTestSuite) SetupSubTest() {
s.ResetLazyObjects()
}
// TestStartServerWithInvalidBind checks that server.Start fails on an
// unparsable bind address, returns no server, and invokes the cancel callback
// it was handed.
func (s *ServerTestSuite) TestStartServerWithInvalidBind() {
	const (
		waitFor = 100 * time.Millisecond
		tick    = 10 * time.Millisecond
	)

	ctx, stop := context.WithCancel(s.T().Context())

	// Wrap the real cancel func so the test can observe that it was invoked
	var stopped atomic.Bool
	onStop := func() {
		stop()
		stopped.Store(true)
	}

	// Must be set before the router is first requested
	s.config().Bind = "-1.-1.-1.-1"

	srv, err := server.Start(onStop, s.router())
	s.Require().Error(err)
	s.Nil(srv)
	s.Contains(err.Error(), "can't start server")

	// The callback must have been called...
	s.Require().Eventually(func() bool {
		return stopped.Load()
	}, waitFor, tick)

	// ...and the context it guards must be done as a result
	s.Require().Eventually(func() bool {
		return ctx.Err() != nil
	}, waitFor, tick)
}
// TestShutdown starts a server on the suite's default config and verifies
// that a graceful shutdown bounded by a 10-second timeout does not panic.
func (s *ServerTestSuite) TestShutdown() {
	// Only the cancel func is needed: it is the callback handed to server.Start
	_, stop := context.WithCancel(context.Background())
	defer stop()

	srv, startErr := server.Start(stop, s.router())
	s.Require().NoError(startErr)
	s.NotNil(srv)

	ctx, release := context.WithTimeout(s.T().Context(), 10*time.Second)
	defer release()

	shutdown := func() {
		srv.Shutdown(ctx)
	}

	// Shutdown must neither panic nor hang
	s.NotPanics(shutdown)
}
// TestWithCORS verifies the headers produced by the WithCORS middleware:
// with an allowed origin configured, the origin and "GET, OPTIONS" methods
// headers are set; with an empty origin, both headers are absent.
func (s *ServerTestSuite) TestWithCORS() {
	cases := []struct {
		name        string
		allowOrigin string
		wantOrigin  string
		wantMethods string
	}{
		{
			name:        "WithCORSOrigin",
			allowOrigin: "https://example.com",
			wantOrigin:  "https://example.com",
			wantMethods: "GET, OPTIONS",
		},
		// Everything except the name is left at its zero value (empty string)
		{name: "NoCORSOrigin"},
	}

	for _, tc := range cases {
		s.Run(tc.name, func() {
			// Configure first: the router is built from the config lazily
			s.config().CORSAllowOrigin = tc.allowOrigin

			router := s.router()
			router.GET("/test", router.WithCORS(s.mockHandler))

			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

			got := rec.Header()
			s.Equal(tc.wantOrigin, got.Get(httpheaders.AccessControlAllowOrigin))
			s.Equal(tc.wantMethods, got.Get(httpheaders.AccessControlAllowMethods))
		})
	}
}
// TestWithSecret verifies the WithSecret middleware: a matching bearer token
// passes, a mismatching one is rejected with 403, and requests pass when no
// secret is configured at all.
func (s *ServerTestSuite) TestWithSecret() {
	cases := []struct {
		name          string
		secret        string
		authorization string
		wantStatus    int
	}{
		{
			name:          "ValidSecret",
			secret:        "test-secret",
			authorization: "Bearer test-secret",
			wantStatus:    http.StatusOK,
		},
		{
			name:          "InvalidSecret",
			secret:        "foo-secret",
			authorization: "Bearer wrong-secret",
			wantStatus:    http.StatusForbidden,
		},
		{
			// Neither a secret nor an Authorization header
			name:       "NoSecretConfigured",
			wantStatus: http.StatusOK,
		},
	}

	for _, tc := range cases {
		s.Run(tc.name, func() {
			// Configure first: the router is built from the config lazily
			s.config().Secret = tc.secret

			router := s.router()
			guarded := router.WithSecret(s.mockHandler)
			router.GET("/test", router.WithReportError(guarded))

			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			if len(tc.authorization) > 0 {
				req.Header.Set(httpheaders.Authorization, tc.authorization)
			}

			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, req)

			s.Equal(tc.wantStatus, rec.Code)
		})
	}
}
// TestIntoSuccess checks that a handler which writes 200 and returns no error
// yields a 200 response when wrapped in WithReportError.
func (s *ServerTestSuite) TestIntoSuccess() {
	okHandler := func(_ string, w server.ResponseWriter, _ *http.Request) *server.Error {
		w.WriteHeader(http.StatusOK)
		return nil
	}

	router := s.router()
	router.GET("/test", router.WithReportError(okHandler))

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	s.Equal(http.StatusOK, rec.Code)
}
// TestIntoWithError checks that an error returned by a handler is turned by
// WithReportError into a 500 response with a text/plain content type.
func (s *ServerTestSuite) TestIntoWithError() {
	cause := errctx.NewTextError("test error", 0)

	failingHandler := func(_ string, _ server.ResponseWriter, _ *http.Request) *server.Error {
		return server.NewError(cause, "test-category")
	}

	router := s.router()
	router.GET("/test", router.WithReportError(failingHandler))

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	s.Equal(http.StatusInternalServerError, rec.Code)
	s.Equal("text/plain", rec.Header().Get(httpheaders.ContentType))
}
// TestIntoPanicWithError checks that a handler panicking with an error value
// does not crash the router when wrapped in WithPanic and WithReportError:
// ServeHTTP returns normally and the response status is 500.
func (s *ServerTestSuite) TestIntoPanicWithError() {
	cause := errors.New("panic error")

	panickingHandler := func(_ string, _ server.ResponseWriter, _ *http.Request) *server.Error {
		panic(cause)
	}

	router := s.router()
	recovered := router.WithPanic(panickingHandler)
	router.GET("/test", router.WithReportError(recovered))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rec := httptest.NewRecorder()

	serve := func() {
		router.ServeHTTP(rec, req)
	}
	s.NotPanics(serve)

	s.Equal(http.StatusInternalServerError, rec.Code)
}
// TestIntoPanicWithAbortHandler checks that WithPanic does not swallow
// http.ErrAbortHandler: the panic must escape ServeHTTP carrying that exact
// value.
func (s *ServerTestSuite) TestIntoPanicWithAbortHandler() {
	abortingHandler := func(reqID string, rw server.ResponseWriter, r *http.Request) *server.Error {
		panic(http.ErrAbortHandler)
	}

	s.router().GET("/test", s.router().WithPanic(abortingHandler))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rw := httptest.NewRecorder()

	// Assert on the panic value rather than on "some panic": a bare Panics
	// would also pass if the router blew up for an unrelated reason (for
	// example a nil dereference), hiding a real regression.
	s.PanicsWithValue(http.ErrAbortHandler, func() {
		s.router().ServeHTTP(rw, req)
	})
}
// TestIntoPanicWithNonError checks that a handler panicking with a plain
// string (not an error) is handled as well: wrapped in WithPanic and
// WithReportError, ServeHTTP returns normally and the response status is 500.
func (s *ServerTestSuite) TestIntoPanicWithNonError() {
	panickingHandler := func(_ string, _ server.ResponseWriter, _ *http.Request) *server.Error {
		panic("string panic")
	}

	router := s.router()
	recovered := router.WithPanic(panickingHandler)
	router.GET("/test", router.WithReportError(recovered))

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rec := httptest.NewRecorder()

	serve := func() {
		router.ServeHTTP(rec, req)
	}
	s.NotPanics(serve)

	s.Equal(http.StatusInternalServerError, rec.Code)
}
// mockHandler is a no-op route handler used by the middleware tests: it
// ignores all of its arguments, writes nothing to rw, and reports no error.
func (s *ServerTestSuite) mockHandler(reqID string, rw server.ResponseWriter, r *http.Request) *server.Error {
return nil
}
// TestServerTestSuite is the go test entry point that runs every method of
// ServerTestSuite through testify's suite runner.
func TestServerTestSuite(t *testing.T) {
	suite.Run(t, &ServerTestSuite{})
}