From 21de240843b8b1d5442776115141ea5f703ed45b Mon Sep 17 00:00:00 2001
From: bing liu <liubing0427@users.noreply.github.com>
Date: Fri, 5 Jan 2024 11:01:17 +0800
Subject: [PATCH] feat: add server option for NotFoundHandler and
 MethodNotAllowedHandler (#3131)

Co-authored-by: Miles Liu <milesliu@birentech.com>
---
 transport/http/server.go      | 16 ++++++++++++++--
 transport/http/server_test.go | 16 ++++++++++++++++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/transport/http/server.go b/transport/http/server.go
index c967e4a1b..9c5550b22 100644
--- a/transport/http/server.go
+++ b/transport/http/server.go
@@ -141,6 +141,18 @@ func PathPrefix(prefix string) ServerOption {
 	}
 }
 
+func NotFoundHandler(handler http.Handler) ServerOption {
+	return func(s *Server) {
+		s.router.NotFoundHandler = handler
+	}
+}
+
+func MethodNotAllowedHandler(handler http.Handler) ServerOption {
+	return func(s *Server) {
+		s.router.MethodNotAllowedHandler = handler
+	}
+}
+
 // Server is an HTTP server wrapper.
 type Server struct {
 	*http.Server
@@ -177,12 +189,12 @@ func NewServer(opts ...ServerOption) *Server {
 		strictSlash: true,
 		router:      mux.NewRouter(),
 	}
+	srv.router.NotFoundHandler = http.DefaultServeMux
+	srv.router.MethodNotAllowedHandler = http.DefaultServeMux
 	for _, o := range opts {
 		o(srv)
 	}
 	srv.router.StrictSlash(srv.strictSlash)
-	srv.router.NotFoundHandler = http.DefaultServeMux
-	srv.router.MethodNotAllowedHandler = http.DefaultServeMux
 	srv.router.Use(srv.filter())
 	srv.Server = &http.Server{
 		Handler:   FilterChain(srv.filters...)(srv.router),
diff --git a/transport/http/server_test.go b/transport/http/server_test.go
index 781c06a71..444775048 100644
--- a/transport/http/server_test.go
+++ b/transport/http/server_test.go
@@ -371,3 +371,19 @@ func TestListener(t *testing.T) {
 		t.Errorf("expected not empty")
 	}
 }
+
+func TestNotFoundHandler(t *testing.T) {
+	mux := http.NewServeMux()
+	srv := NewServer(NotFoundHandler(mux))
+	if !reflect.DeepEqual(srv.router.NotFoundHandler, mux) {
+		t.Errorf("expected %v got %v", mux, srv.router.NotFoundHandler)
+	}
+}
+
+func TestMethodNotAllowedHandler(t *testing.T) {
+	mux := http.NewServeMux()
+	srv := NewServer(MethodNotAllowedHandler(mux))
+	if !reflect.DeepEqual(srv.router.MethodNotAllowedHandler, mux) {
+		t.Errorf("expected %v got %v", mux, srv.router.MethodNotAllowedHandler)
+	}
+}