From 52876d3e207381569e4b27f7d76865e134c56f78 Mon Sep 17 00:00:00 2001 From: longxboy Date: Wed, 20 Oct 2021 22:03:39 +0800 Subject: [PATCH] test: add app and transport test (#1572) * add app and transport test --- app.go | 10 ++++-- app_test.go | 29 ++++++++++++++++ transport/transport_test.go | 67 +++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 transport/transport_test.go diff --git a/app.go b/app.go index ab425b643..aa7ba8248 100644 --- a/app.go +++ b/app.go @@ -31,6 +31,7 @@ type App struct { opts options ctx context.Context cancel func() + lk sync.Mutex instance *registry.ServiceInstance } @@ -104,7 +105,9 @@ func (a *App) Run() error { if err := a.opts.registrar.Register(rctx, instance); err != nil { return err } + a.lk.Lock() a.instance = instance + a.lk.Unlock() } c := make(chan os.Signal, 1) signal.Notify(c, a.opts.sigs...) @@ -130,10 +133,13 @@ func (a *App) Run() error { // Stop gracefully stops the application. func (a *App) Stop() error { - if a.opts.registrar != nil && a.instance != nil { + a.lk.Lock() + instance := a.instance + a.lk.Unlock() + if a.opts.registrar != nil && instance != nil { ctx, cancel := context.WithTimeout(a.opts.ctx, a.opts.registrarTimeout) defer cancel() - if err := a.opts.registrar.Deregister(ctx, a.instance); err != nil { + if err := a.opts.registrar.Deregister(ctx, instance); err != nil { return err } } diff --git a/app_test.go b/app_test.go index 37195e7b1..04aeecfad 100644 --- a/app_test.go +++ b/app_test.go @@ -2,7 +2,9 @@ package kratos import ( "context" + "fmt" "reflect" + "sync" "testing" "time" @@ -12,6 +14,32 @@ import ( "github.com/stretchr/testify/assert" ) +type mockRegistry struct { + lk sync.Mutex + service map[string]*registry.ServiceInstance +} + +func (r *mockRegistry) Register(ctx context.Context, service *registry.ServiceInstance) error { + if service == nil || service.ID == "" { + return fmt.Errorf("no service id") + } + r.lk.Lock() + defer r.lk.Unlock() + r.service[service.ID] = service + return nil +} + +// Deregister the registration. +func (r *mockRegistry) Deregister(ctx context.Context, service *registry.ServiceInstance) error { + r.lk.Lock() + defer r.lk.Unlock() + if r.service[service.ID] == nil { + return fmt.Errorf("deregister service not found") + } + delete(r.service, service.ID) + return nil +} + func TestApp(t *testing.T) { hs := http.NewServer() gs := grpc.NewServer() @@ -19,6 +47,7 @@ func TestApp(t *testing.T) { Name("kratos"), Version("v1.0.0"), Server(hs, gs), + Registrar(&mockRegistry{service: make(map[string]*registry.ServiceInstance)}), ) time.AfterFunc(time.Second, func() { _ = app.Stop() diff --git a/transport/transport_test.go b/transport/transport_test.go new file mode 100644 index 000000000..a453a8e51 --- /dev/null +++ b/transport/transport_test.go @@ -0,0 +1,67 @@ +package transport + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +// mockTransport is a gRPC transport. +type mockTransport struct { + endpoint string + operation string +} + +// Kind returns the transport kind. +func (tr *mockTransport) Kind() Kind { + return KindGRPC +} + +// Endpoint returns the transport endpoint. +func (tr *mockTransport) Endpoint() string { + return tr.endpoint +} + +// Operation returns the transport operation. +func (tr *mockTransport) Operation() string { + return tr.operation +} + +// RequestHeader returns the request header. +func (tr *mockTransport) RequestHeader() Header { + return nil +} + +// ReplyHeader returns the reply header. +func (tr *mockTransport) ReplyHeader() Header { + return nil +} + +func TestServerTransport(t *testing.T) { + ctx := context.Background() + + ctx = NewServerContext(ctx, &mockTransport{endpoint: "test_endpoint"}) + tr, ok := FromServerContext(ctx) + + assert.Equal(t, true, ok) + assert.NotNil(t, tr) + mtr, ok := tr.(*mockTransport) + assert.Equal(t, true, ok) + assert.NotNil(t, mtr) + assert.Equal(t, mtr.endpoint, "test_endpoint") +} + +func TestClientTransport(t *testing.T) { + ctx := context.Background() + + ctx = NewClientContext(ctx, &mockTransport{endpoint: "test_endpoint"}) + tr, ok := FromClientContext(ctx) + + assert.Equal(t, true, ok) + assert.NotNil(t, tr) + mtr, ok := tr.(*mockTransport) + assert.Equal(t, true, ok) + assert.NotNil(t, mtr) + assert.Equal(t, mtr.endpoint, "test_endpoint") +}