diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 748302c24..0ae458d6a 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "net" "net/url" - "sync" "time" "github.com/go-kratos/kratos/v2/internal/endpoint" @@ -75,6 +74,13 @@ func TLSConfig(c *tls.Config) ServerOption { } } +// Listener with server lis +func Listener(lis net.Listener) ServerOption { + return func(s *Server) { + s.lis = lis + } +} + // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the server. func UnaryInterceptor(in ...grpc.UnaryServerInterceptor) ServerOption { return func(s *Server) { @@ -95,7 +101,6 @@ type Server struct { baseCtx context.Context tlsConf *tls.Config lis net.Listener - once sync.Once err error network string address string @@ -139,6 +144,8 @@ func NewServer(opts ...ServerOption) *Server { } srv.Server = grpc.NewServer(grpcOpts...) srv.metadata = apimd.NewServer(srv.Server) + // listen and endpoint + srv.err = srv.listenAndEndpoint() // internal register grpc_health_v1.RegisterHealthServer(srv.Server, srv.health) apimd.RegisterMetadataServer(srv.Server, srv.metadata) @@ -150,21 +157,6 @@ func NewServer(opts ...ServerOption) *Server { // examples: // grpc://127.0.0.1:9000?isSecure=false func (s *Server) Endpoint() (*url.URL, error) { - s.once.Do(func() { - lis, err := net.Listen(s.network, s.address) - if err != nil { - s.err = err - return - } - addr, err := host.Extract(s.address, lis) - if err != nil { - _ = lis.Close() - s.err = err - return - } - s.lis = lis - s.endpoint = endpoint.NewEndpoint("grpc", addr, s.tlsConf != nil) - }) if s.err != nil { return nil, s.err } @@ -173,8 +165,8 @@ func (s *Server) Endpoint() (*url.URL, error) { // Start start the gRPC server. func (s *Server) Start(ctx context.Context) error { - if _, err := s.Endpoint(); err != nil { - return err + if s.err != nil { + return s.err } s.baseCtx = ctx s.log.Infof("[gRPC] server listening on: %s", s.lis.Addr().String()) @@ -190,6 +182,23 @@ func (s *Server) Stop(ctx context.Context) error { return nil } +func (s *Server) listenAndEndpoint() error { + if s.lis == nil { + lis, err := net.Listen(s.network, s.address) + if err != nil { + return err + } + s.lis = lis + } + addr, err := host.Extract(s.address, s.lis) + if err != nil { + _ = s.lis.Close() + return err + } + s.endpoint = endpoint.NewEndpoint("grpc", addr, s.tlsConf != nil) + return nil +} + func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { ctx, cancel := ic.Merge(ctx, s.baseCtx) diff --git a/transport/grpc/server_test.go b/transport/grpc/server_test.go index 87c3483cc..14079ce4b 100644 --- a/transport/grpc/server_test.go +++ b/transport/grpc/server_test.go @@ -117,11 +117,9 @@ func TestNetwork(t *testing.T) { } func TestAddress(t *testing.T) { - o := &Server{} v := "abc" - Address(v)(o) + o := NewServer(Address(v)) assert.Equal(t, v, o.address) - u, err := o.Endpoint() assert.NotNil(t, err) assert.Nil(t, u) diff --git a/transport/http/server.go b/transport/http/server.go index da8a29fe6..f30268cfe 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "net/url" - "sync" "time" "github.com/go-kratos/kratos/v2/internal/endpoint" @@ -107,12 +106,18 @@ func StrictSlash(strictSlash bool) ServerOption { } } +// Listener with server lis +func Listener(lis net.Listener) ServerOption { + return func(s *Server) { + s.lis = lis + } +} + // Server is an HTTP server wrapper. type Server struct { *http.Server lis net.Listener tlsConf *tls.Config - once sync.Once endpoint *url.URL err error network string @@ -149,6 +154,7 @@ func NewServer(opts ...ServerOption) *Server { Handler: FilterChain(srv.filters...)(srv.router), TLSConfig: srv.tlsConf, } + srv.err = srv.listenAndEndpoint() return srv } @@ -219,22 +225,6 @@ func (s *Server) filter() mux.MiddlewareFunc { // examples: // http://127.0.0.1:8000?isSecure=false func (s *Server) Endpoint() (*url.URL, error) { - s.once.Do(func() { - lis, err := net.Listen(s.network, s.address) - if err != nil { - s.err = err - return - } - addr, err := host.Extract(s.address, lis) - if err != nil { - lis.Close() - s.err = err - return - } - s.lis = lis - - s.endpoint = endpoint.NewEndpoint("http", addr, s.tlsConf != nil) - }) if s.err != nil { return nil, s.err } @@ -243,8 +233,8 @@ func (s *Server) Endpoint() (*url.URL, error) { // Start start the HTTP server. func (s *Server) Start(ctx context.Context) error { - if _, err := s.Endpoint(); err != nil { - return err + if s.err != nil { + return s.err } s.BaseContext = func(net.Listener) context.Context { return ctx @@ -267,3 +257,20 @@ func (s *Server) Stop(ctx context.Context) error { s.log.Info("[HTTP] server stopping") return s.Shutdown(ctx) } + +func (s *Server) listenAndEndpoint() error { + if s.lis == nil { + lis, err := net.Listen(s.network, s.address) + if err != nil { + return err + } + s.lis = lis + } + addr, err := host.Extract(s.address, s.lis) + if err != nil { + _ = s.lis.Close() + return err + } + s.endpoint = endpoint.NewEndpoint("http", addr, s.tlsConf != nil) + return nil +}