From 21de240843b8b1d5442776115141ea5f703ed45b Mon Sep 17 00:00:00 2001 From: bing liu <liubing0427@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:01:17 +0800 Subject: [PATCH] feat: add server option for NotFoundHandler and MethodNotAllowedHandler (#3131) Co-authored-by: Miles Liu <milesliu@birentech.com> --- transport/http/server.go | 16 ++++++++++++++-- transport/http/server_test.go | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/transport/http/server.go b/transport/http/server.go index c967e4a1b..9c5550b22 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -141,6 +141,18 @@ func PathPrefix(prefix string) ServerOption { } } +func NotFoundHandler(handler http.Handler) ServerOption { + return func(s *Server) { + s.router.NotFoundHandler = handler + } +} + +func MethodNotAllowedHandler(handler http.Handler) ServerOption { + return func(s *Server) { + s.router.MethodNotAllowedHandler = handler + } +} + // Server is an HTTP server wrapper. type Server struct { *http.Server @@ -177,12 +189,12 @@ func NewServer(opts ...ServerOption) *Server { strictSlash: true, router: mux.NewRouter(), } + srv.router.NotFoundHandler = http.DefaultServeMux + srv.router.MethodNotAllowedHandler = http.DefaultServeMux for _, o := range opts { o(srv) } srv.router.StrictSlash(srv.strictSlash) - srv.router.NotFoundHandler = http.DefaultServeMux - srv.router.MethodNotAllowedHandler = http.DefaultServeMux srv.router.Use(srv.filter()) srv.Server = &http.Server{ Handler: FilterChain(srv.filters...)(srv.router), diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 781c06a71..444775048 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -371,3 +371,19 @@ func TestListener(t *testing.T) { t.Errorf("expected not empty") } } + +func TestNotFoundHandler(t *testing.T) { + mux := http.NewServeMux() + srv := NewServer(NotFoundHandler(mux)) + if !reflect.DeepEqual(srv.router.NotFoundHandler, mux) { + t.Errorf("expected %v got %v", mux, srv.router.NotFoundHandler) + } +} + +func TestMethodNotAllowedHandler(t *testing.T) { + mux := http.NewServeMux() + srv := NewServer(MethodNotAllowedHandler(mux)) + if !reflect.DeepEqual(srv.router.MethodNotAllowedHandler, mux) { + t.Errorf("expected %v got %v", mux, srv.router.MethodNotAllowedHandler) + } +}