From c2e7ab8d413b1156b426a622a02768e0229d6d0a Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Thu, 21 Nov 2024 12:11:00 +0200 Subject: [PATCH] fixed oauth2 redirect test --- apis/record_auth_with_oauth2_redirect_test.go | 50 +++++++++++++------ tools/subscriptions/client.go | 4 +- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/apis/record_auth_with_oauth2_redirect_test.go b/apis/record_auth_with_oauth2_redirect_test.go index d9cf77b1..f8cb06ce 100644 --- a/apis/record_auth_with_oauth2_redirect_test.go +++ b/apis/record_auth_with_oauth2_redirect_test.go @@ -56,18 +56,22 @@ func TestRecordAuthWithOAuth2Redirect(t *testing.T) { } } + // note: don't exit because it is usually called as part of a separate goroutine checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) { if len(expectedMessages[clientId]) == 0 { - t.Fatalf("Unexpected client %q message, got %s:\n%s", clientId, msg.Name, msg.Data) + t.Errorf("Unexpected client %q message, got %q:\n%q", clientId, msg.Name, msg.Data) + return } if msg.Name != "@oauth2" { - t.Fatalf("Expected @oauth2 msg.Name, got %q", msg.Name) + t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name) + return } for _, txt := range expectedMessages[clientId] { if !strings.Contains(string(msg.Data), txt) { - t.Fatalf("Failed to find %q in \n%s", txt, msg.Data) + t.Errorf("Failed to find %q in \n%s", txt, msg.Data) + return } } } @@ -91,19 +95,37 @@ func TestRecordAuthWithOAuth2Redirect(t *testing.T) { for { select { - case msg := <-clients["c1"].Channel(): - checkClientMessages(t, "c1", msg, expectedMessages) - case msg := <-clients["c2"].Channel(): - checkClientMessages(t, "c2", msg, expectedMessages) - case msg := <-clients["c3"].Channel(): - checkClientMessages(t, "c3", msg, expectedMessages) - case msg := <-clients["c4"].Channel(): - checkClientMessages(t, "c4", msg, expectedMessages) - case msg := <-clients["c5"].Channel(): - checkClientMessages(t, "c5", msg, expectedMessages) + case msg, ok := <-clients["c1"].Channel(): + if ok { + checkClientMessages(t, "c1", msg, expectedMessages) + } else { + t.Errorf("Unexpected c1 closed channel") + } + case msg, ok := <-clients["c2"].Channel(): + if ok { + checkClientMessages(t, "c2", msg, expectedMessages) + } else { + t.Errorf("Unexpected c2 closed channel") + } + case msg, ok := <-clients["c3"].Channel(): + if ok { + checkClientMessages(t, "c3", msg, expectedMessages) + } else { + t.Errorf("Unexpected c3 closed channel") + } + case msg, ok := <-clients["c4"].Channel(): + if ok { + checkClientMessages(t, "c4", msg, expectedMessages) + } else { + t.Errorf("Unexpected c4 closed channel") + } + case _, ok := <-clients["c5"].Channel(): + if ok { + t.Errorf("Expected c5 channel to be closed") + } case <-ctx.Done(): for _, c := range clients { - close(c.Channel()) + c.Discard() } return } diff --git a/tools/subscriptions/client.go b/tools/subscriptions/client.go index 0bcb3594..e9bc01c8 100644 --- a/tools/subscriptions/client.go +++ b/tools/subscriptions/client.go @@ -263,9 +263,9 @@ func (c *DefaultClient) Discard() { return } - c.isDiscarded = true - close(c.channel) + + c.isDiscarded = true } // IsDiscarded implements the [Client.IsDiscarded] interface method.