diff --git a/app.go b/app.go index fe060f383..0d8e4d3e4 100644 --- a/app.go +++ b/app.go @@ -53,14 +53,7 @@ func (a *App) Run() error { if err != nil { return err } - ctx := NewContext(a.ctx, AppInfo{ - ID: instance.ID, - Name: instance.Name, - Version: instance.Version, - Metadata: instance.Metadata, - Endpoints: instance.Endpoints, - }) - eg, ctx := errgroup.WithContext(ctx) + eg, ctx := errgroup.WithContext(a.ctx) wg := sync.WaitGroup{} for _, srv := range a.opts.servers { srv := srv diff --git a/context.go b/context.go deleted file mode 100644 index e97b0fd16..000000000 --- a/context.go +++ /dev/null @@ -1,27 +0,0 @@ -package kratos - -import ( - "context" -) - -// AppInfo is application context value. -type AppInfo struct { - ID string - Name string - Version string - Metadata map[string]string - Endpoints []string -} - -type appKey struct{} - -// NewContext returns a new Context that carries value. -func NewContext(ctx context.Context, s AppInfo) context.Context { - return context.WithValue(ctx, appKey{}, s) -} - -// FromContext returns the Transport value stored in ctx, if any. -func FromContext(ctx context.Context) (s AppInfo, ok bool) { - s, ok = ctx.Value(appKey{}).(AppInfo) - return -} diff --git a/internal/context/context.go b/internal/context/context.go deleted file mode 100644 index 0b2c3388e..000000000 --- a/internal/context/context.go +++ /dev/null @@ -1,115 +0,0 @@ -package context - -import ( - "context" - "sync" - "sync/atomic" - "time" -) - -type mergeCtx struct { - parent1, parent2 context.Context - - done chan struct{} - doneMark uint32 - doneOnce sync.Once - doneErr error - - cancelCh chan struct{} - cancelOnce sync.Once -} - -// Merge merges two contexts into one. -func Merge(parent1, parent2 context.Context) (context.Context, context.CancelFunc) { - mc := &mergeCtx{ - parent1: parent1, - parent2: parent2, - done: make(chan struct{}), - cancelCh: make(chan struct{}), - } - select { - case <-parent1.Done(): - mc.finish(parent1.Err()) - case <-parent2.Done(): - mc.finish(parent2.Err()) - default: - go mc.wait() - } - return mc, mc.cancel -} - -func (mc *mergeCtx) finish(err error) error { - mc.doneOnce.Do(func() { - mc.doneErr = err - atomic.StoreUint32(&mc.doneMark, 1) - close(mc.done) - }) - return mc.doneErr -} - -func (mc *mergeCtx) wait() { - var err error - select { - case <-mc.parent1.Done(): - err = mc.parent1.Err() - case <-mc.parent2.Done(): - err = mc.parent2.Err() - case <-mc.cancelCh: - err = context.Canceled - } - mc.finish(err) -} - -func (mc *mergeCtx) cancel() { - mc.cancelOnce.Do(func() { - close(mc.cancelCh) - }) -} - -// Done implements context.Context. -func (mc *mergeCtx) Done() <-chan struct{} { - return mc.done -} - -// Err implements context.Context. -func (mc *mergeCtx) Err() error { - if atomic.LoadUint32(&mc.doneMark) != 0 { - return mc.doneErr - } - var err error - select { - case <-mc.parent1.Done(): - err = mc.parent1.Err() - case <-mc.parent2.Done(): - err = mc.parent2.Err() - case <-mc.cancelCh: - err = context.Canceled - default: - return nil - } - return mc.finish(err) -} - -// Deadline implements context.Context. -func (mc *mergeCtx) Deadline() (time.Time, bool) { - d1, ok1 := mc.parent1.Deadline() - d2, ok2 := mc.parent2.Deadline() - switch { - case !ok1: - return d2, ok2 - case !ok2: - return d1, ok1 - case d1.Before(d2): - return d1, true - default: - return d2, true - } -} - -// Value implements context.Context. -func (mc *mergeCtx) Value(key interface{}) interface{} { - if v := mc.parent1.Value(key); v != nil { - return v - } - return mc.parent2.Value(key) -} diff --git a/transport/grpc/server.go b/transport/grpc/server.go index 21868f2a0..89a811c02 100644 --- a/transport/grpc/server.go +++ b/transport/grpc/server.go @@ -8,7 +8,6 @@ import ( "time" apimd "github.com/go-kratos/kratos/v2/api/metadata" - ic "github.com/go-kratos/kratos/v2/internal/context" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" @@ -79,7 +78,6 @@ func Options(opts ...grpc.ServerOption) ServerOption { // Server is a gRPC server wrapper. type Server struct { *grpc.Server - ctx context.Context lis net.Listener once sync.Once err error @@ -158,7 +156,6 @@ func (s *Server) Start(ctx context.Context) error { if _, err := s.Endpoint(); err != nil { return err } - s.ctx = ctx s.log.Infof("[gRPC] server listening on: %s", s.lis.Addr().String()) s.health.Resume() return s.Serve(s.lis) @@ -174,7 +171,7 @@ func (s *Server) Stop(ctx context.Context) error { func (s *Server) unaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - ctx, cancel := ic.Merge(ctx, s.ctx) + ctx, cancel := context.WithCancel(ctx) defer cancel() md, _ := grpcmd.FromIncomingContext(ctx) ctx = transport.NewServerContext(ctx, &Transport{ diff --git a/transport/http/server.go b/transport/http/server.go index 709fe4c7e..365ec2dd2 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -9,7 +9,6 @@ import ( "sync" "time" - ic "github.com/go-kratos/kratos/v2/internal/context" "github.com/go-kratos/kratos/v2/internal/host" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" @@ -90,7 +89,6 @@ func ErrorEncoder(en EncodeErrorFunc) ServerOption { // Server is an HTTP server wrapper. type Server struct { *http.Server - ctx context.Context lis net.Listener once sync.Once endpoint *url.URL @@ -157,7 +155,7 @@ func (s *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { func (s *Server) filter() mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ctx, cancel := ic.Merge(req.Context(), s.ctx) + ctx, cancel := context.WithCancel(req.Context()) defer cancel() if s.timeout > 0 { ctx, cancel = context.WithTimeout(ctx, s.timeout) @@ -216,7 +214,6 @@ func (s *Server) Start(ctx context.Context) error { if _, err := s.Endpoint(); err != nil { return err } - s.ctx = ctx s.log.Infof("[HTTP] server listening on: %s", s.lis.Addr().String()) if err := s.Serve(s.lis); !errors.Is(err, http.ErrServerClosed) { return err