package core

import (
	"errors"
	"fmt"
	"testing"
)

// TestGetDefaultRetryInterval checks the interval returned by
// getDefaultRetryInterval for a few representative arguments.
func TestGetDefaultRetryInterval(t *testing.T) {
	t.Parallel()

	// Table-driven so the expected value in the failure message is always
	// the one actually asserted (a hand-written message previously drifted
	// out of sync with its check).
	scenarios := []struct {
		arg        int
		expectedMs int64
	}{
		{-1, 1000},
		{999, 1000},
		{3, 200},
	}

	for _, s := range scenarios {
		if i := getDefaultRetryInterval(s.arg); i.Milliseconds() != s.expectedMs {
			t.Fatalf("Expected %dms for %d, got %v", s.expectedMs, s.arg, i)
		}
	}
}

// TestBaseLockRetry checks how many attempts baseLockRetry makes for
// different error kinds returned by the callback, and whether it ends with
// an error.
//
// In each scenario the callback returns s.err for every attempt below
// failUntilAttempt and nil from then on. The scenario table encodes the
// expectation that only a "database is locked" error is retried; nil or any
// other error stops after the first attempt.
func TestBaseLockRetry(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		err              error
		failUntilAttempt int
		expectedAttempts int
	}{
		{nil, 3, 1},
		{errors.New("test"), 3, 1},
		{errors.New("database is locked"), 3, 3},
	}

	for i, s := range scenarios {
		t.Run(fmt.Sprintf("%d_%#v", i, s.err), func(t *testing.T) {
			lastAttempt := 0

			// The second argument gives baseLockRetry room beyond
			// failUntilAttempt; presumably it is the max attempt count
			// (confirm against baseLockRetry's definition).
			err := baseLockRetry(func(attempt int) error {
				lastAttempt = attempt
				if attempt < s.failUntilAttempt {
					return s.err
				}
				return nil
			}, s.failUntilAttempt+2)

			if lastAttempt != s.expectedAttempts {
				t.Errorf("Expected lastAttempt to be %d, got %d", s.expectedAttempts, lastAttempt)
			}

			// The callback was retried until it succeeded, so no error is expected.
			if s.failUntilAttempt == s.expectedAttempts && err != nil {
				t.Fatalf("Expected nil, got err %v", err)
			}

			// Retrying stopped before the callback succeeded on a non-nil
			// error, so that error should be propagated.
			if s.failUntilAttempt != s.expectedAttempts && s.err != nil && err == nil {
				t.Fatalf("Expected error %q, got nil", s.err)
			}
		})
	}
}