From 4df974ccc4c3b733a95b22dc15f92245ed3c9e4c Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Wed, 25 Jun 2025 21:10:16 +0100 Subject: [PATCH] pacer: fix nil pointer deref in RetryError - fixes #8077 Before this change, if RetryAfterError was called with a nil err, then it's Error method would return this when wrapped in a fmt.Errorf statement error %!v(PANIC=Error method: runtime error: invalid memory address or nil pointer dereference)) Looking at the code, it looks like RetryAfterError will usually be called with a nil pointer, so this patch makes sure it has a sensible error. --- lib/pacer/pacer.go | 11 +++++- lib/pacer/pacer_test.go | 81 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/lib/pacer/pacer.go b/lib/pacer/pacer.go index 67f3ffece..1c7cc51e5 100644 --- a/lib/pacer/pacer.go +++ b/lib/pacer/pacer.go @@ -2,6 +2,8 @@ package pacer import ( + "errors" + "fmt" "sync" "time" @@ -235,15 +237,22 @@ type retryAfterError struct { } func (r *retryAfterError) Error() string { - return r.error.Error() + return fmt.Sprintf("%v: trying again in %v", r.error, r.retryAfter) } func (r *retryAfterError) Cause() error { return r.error } +func (r *retryAfterError) Unwrap() error { + return r.error +} + // RetryAfterError returns a wrapped error that can be used by Calculator implementations func RetryAfterError(err error, retryAfter time.Duration) error { + if err == nil { + err = errors.New("too many requests") + } return &retryAfterError{ error: err, retryAfter: retryAfter, diff --git a/lib/pacer/pacer_test.go b/lib/pacer/pacer_test.go index 144298123..3ac9c3741 100644 --- a/lib/pacer/pacer_test.go +++ b/lib/pacer/pacer_test.go @@ -2,6 +2,8 @@ package pacer import ( "errors" + "fmt" + "strings" "sync" "testing" "time" @@ -350,3 +352,82 @@ func TestCallParallel(t *testing.T) { assert.Equal(t, 5, called) wait.Broadcast() } + +func TestRetryAfterError_NonNilErr(t *testing.T) { + orig := errors.New("test failure") + dur := 2 * time.Second + err := RetryAfterError(orig, dur) + + rErr, ok := err.(*retryAfterError) + if !ok { + t.Fatalf("expected *retryAfterError, got %T", err) + } + if !strings.Contains(err.Error(), "test failure") { + t.Errorf("Error() = %q, want it to contain original message", err.Error()) + } + if !strings.Contains(err.Error(), dur.String()) { + t.Errorf("Error() = %q, want it to contain retryAfter %v", err.Error(), dur) + } + if rErr.retryAfter != dur { + t.Errorf("retryAfter = %v, want %v", rErr.retryAfter, dur) + } + if !errors.Is(err, orig) { + t.Error("errors.Is(err, orig) = false, want true") + } +} + +func TestRetryAfterError_NilErr(t *testing.T) { + dur := 5 * time.Second + err := RetryAfterError(nil, dur) + if !strings.Contains(err.Error(), "too many requests") { + t.Errorf("Error() = %q, want it to mention default message", err.Error()) + } + if !strings.Contains(err.Error(), dur.String()) { + t.Errorf("Error() = %q, want it to contain retryAfter %v", err.Error(), dur) + } +} + +func TestCauseMethod(t *testing.T) { + orig := errors.New("underlying") + dur := time.Second + rErr := RetryAfterError(orig, dur).(*retryAfterError) + cause := rErr.Cause() + if !errors.Is(cause, orig) { + t.Errorf("Cause() does not wrap original: got %v", cause) + } +} + +func TestIsRetryAfter_True(t *testing.T) { + orig := errors.New("oops") + dur := 3 * time.Second + err := RetryAfterError(orig, dur) + + gotDur, ok := IsRetryAfter(err) + if !ok { + t.Error("IsRetryAfter returned false, want true") + } + if gotDur != dur { + t.Errorf("got %v, want %v", gotDur, dur) + } +} + +func TestIsRetryAfter_Nested(t *testing.T) { + orig := errors.New("fail") + dur := 4 * time.Second + retryErr := RetryAfterError(orig, dur) + nested := fmt.Errorf("wrapped: %w", retryErr) + + gotDur, ok := IsRetryAfter(nested) + if !ok { + t.Error("IsRetryAfter on nested error returned false, want true") + } + if gotDur != dur { + t.Errorf("got %v, want %v", gotDur, dur) + } +} + +func TestIsRetryAfter_False(t *testing.T) { + if _, ok := IsRetryAfter(errors.New("other")); ok { + t.Error("IsRetryAfter = true for non-retry error, want false") + } +}