package retry

import (
	"context"
	"errors"
	"testing"
	"time"
)

// errExpectedFailure is the sentinel error the retried function returns in
// test cases that expect Retry to surface a failure.
var errExpectedFailure = errors.New("expected failure for test purposes")

// TestDelayer asserts that repeated calls to delayer.Delay return the
// configured delays in order and that, once the list is exhausted, every
// further call keeps returning the last configured delay.
func TestDelayer(t *testing.T) {
	delays := []time.Duration{
		time.Millisecond,
		2 * time.Millisecond,
		4 * time.Millisecond,
		10 * time.Millisecond,
	}

	tt := []struct {
		desc       string
		numRetries int // number of Delay calls made before the one under test
		expDelay   time.Duration
	}{
		{"first try", 0, time.Millisecond},
		{"second try", 1, 2 * time.Millisecond},
		{"len(delays) try", len(delays) - 1, delays[len(delays)-1]},
		{"len(delays) + 1 try", len(delays), delays[len(delays)-1]},
		{"len(delays) * 2 try", len(delays) * 2, delays[len(delays)-1]},
	}

	for _, tc := range tt {
		t.Run(tc.desc, func(t *testing.T) {
			// A fresh delayer per case so the cases do not share state.
			d := delayer{Delays: delays}

			// Call Delay numRetries+1 times; only the last result is checked.
			var delay time.Duration
			for i := 0; i <= tc.numRetries; i++ {
				delay = d.Delay()
			}

			if delay != tc.expDelay {
				t.Fatalf(
					"expected delay of %s after %d retries, but got %s",
					tc.expDelay, tc.numRetries, delay)
			}
		})
	}
}

// TestRetry asserts that Retry keeps invoking the supplied function until it
// either reports success or returns an error, that it makes exactly the
// expected number of calls, and that it returns the function's error.
func TestRetry(t *testing.T) {
	delays := []time.Duration{
		time.Millisecond,
		2 * time.Millisecond,
		3 * time.Millisecond,
	}

	tt := []struct {
		desc    string
		tries   int   // call number on which retryFunc returns its final result
		success bool  // success flag returned on the final call
		err     error // error returned on the final call, expected back from Retry
	}{
		{"first try", 1, true, nil},
		{"second try error", 2, false, errExpectedFailure},
		{"third try success", 3, true, nil},
	}

	for _, tc := range tt {
		t.Run(tc.desc, func(t *testing.T) {
			tries := 0

			// retryFunc reports "not done, no error" until the tc.tries-th
			// call, at which point it returns the outcome under test.
			retryFunc := func() (bool, error) {
				tries++
				if tries == tc.tries {
					return tc.success, tc.err
				}
				t.Logf("try #%d unsuccessful: trying again up to %d times", tries, tc.tries)
				return false, nil
			}

			err := Retry(context.Background(), delays, retryFunc)

			// errors.Is(err, nil) is true only when err is nil, so this also
			// covers the success cases. %v is used because either error may
			// be nil, which %s would render as "%!s(<nil>)".
			if !errors.Is(err, tc.err) {
				t.Fatalf("expected error %v, but got error %v", tc.err, err)
			}
			if tries != tc.tries {
				t.Fatalf("expected %d tries, but tried %d times", tc.tries, tries)
			}
		})
	}
}