From 49e7f82fa9f95572b641dfc2573163600d8c6b5a Mon Sep 17 00:00:00 2001 From: gandaldf Date: Mon, 22 Nov 2021 00:07:25 +0100 Subject: [PATCH] Add All() to mockstore and fix typos --- mockstore/store.go | 43 +++++++++++++++++++++++++++++++++++++---- mockstore/store_test.go | 43 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/mockstore/store.go b/mockstore/store.go index e51b672..205c8a7 100644 --- a/mockstore/store.go +++ b/mockstore/store.go @@ -24,17 +24,22 @@ type expectedCommit struct { returnErr error } +type expectedAll struct { + returnMB map[string][]byte + returnErr error +} + type MockStore struct { deleteExpectations []expectedDelete findExpectations []expectedFind commitExpectations []expectedCommit + allExpectations []expectedAll } -// Delete implements the Store interface -func (m *MockStore) ExpectDelete(token string, returnErr error) { +func (m *MockStore) ExpectDelete(token string, err error) { m.deleteExpectations = append(m.deleteExpectations, expectedDelete{ inputToken: token, - returnErr: returnErr, + returnErr: err, }) } @@ -109,7 +114,7 @@ func (m *MockStore) Commit(token string, b []byte, expiry time.Time) (err error) expectationFound bool ) for i, expectation := range m.commitExpectations { - if expectation.inputToken == token && bytes.Compare(expectation.inputB, b) == 0 && expectation.inputExpiry == expiry { + if expectation.inputToken == token && bytes.Equal(expectation.inputB, b) && expectation.inputExpiry == expiry { indexToRemove = i expectationFound = true break @@ -124,3 +129,33 @@ func (m *MockStore) Commit(token string, b []byte, expiry time.Time) (err error) return errToReturn } + +func (m *MockStore) ExpectAll(mb map[string][]byte, err error) { + m.allExpectations = append(m.allExpectations, expectedAll{ + returnMB: mb, + returnErr: err, + }) +} + +// All implements the IterableStore interface +func (m *MockStore) All() (map[string][]byte, error) { + var ( + indexToRemove int + expectationFound bool + ) + for i, expectation := range m.allExpectations { + if len(expectation.returnMB) == 3 { + indexToRemove = i + expectationFound = true + break + } + } + if !expectationFound { + panic("store.All called unexpectedly") + } + + valueToReturn := m.allExpectations[indexToRemove] + m.allExpectations = m.allExpectations[:indexToRemove+copy(m.allExpectations[indexToRemove:], m.allExpectations[indexToRemove+1:])] + + return valueToReturn.returnMB, valueToReturn.returnErr +} diff --git a/mockstore/store_test.go b/mockstore/store_test.go index e3721ac..a8b30e0 100644 --- a/mockstore/store_test.go +++ b/mockstore/store_test.go @@ -3,6 +3,7 @@ package mockstore import ( "bytes" "errors" + "reflect" "testing" "time" ) @@ -56,10 +57,10 @@ func TestMockStore_Find(T *testing.T) { s.ExpectFind(exampleToken, expectedBytes, expectedFound, nil) actualBytes, actualFound, actualErr := s.Find(exampleToken) - if !bytes.Equal(expectedBytes, actualBytes) { + if !bytes.Equal(actualBytes, expectedBytes) { t.Error("returned bytes do not match expectation") } - if expectedFound != actualFound { + if actualFound != expectedFound { t.Error("returned found does not match expectation") } if actualErr != nil { @@ -127,3 +128,41 @@ func TestMockStore_Commit(T *testing.T) { } }) } + +func TestMockStore_All(T *testing.T) { + T.Parallel() + + T.Run("obligatory", func(t *testing.T) { + s := &MockStore{} + + expectedMapBytes := map[string][]byte{"token1": []byte("hello, world 1!"), "token2": []byte("hello, world 2!"), "token3": []byte("hello, world 3!")} + + s.ExpectAll(expectedMapBytes, nil) + + actualMapBytes, actualErr := s.All() + if !reflect.DeepEqual(actualMapBytes, expectedMapBytes) { + t.Error("returned map bytes do not match expectation") + } + if actualErr != nil { + t.Error("unexpected error returned") + } + if len(s.allExpectations) != 0 { + t.Error("expectations left over after exhausting calls") + } + }) + + T.Run("panics with unfound expectation", func(t *testing.T) { + s := &MockStore{} + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic to occur") + } + }() + + _, actualErr := s.All() + if actualErr != nil { + t.Error("unexpected error returned") + } + }) +}