1
0
mirror of https://github.com/alexedwards/scs.git synced 2025-07-17 01:12:21 +02:00

Add All() to mockstore and fix typos

This commit is contained in:
gandaldf
2021-11-22 00:07:25 +01:00
parent 2068e2a219
commit 49e7f82fa9
2 changed files with 80 additions and 6 deletions

View File

@ -24,17 +24,22 @@ type expectedCommit struct {
returnErr error returnErr error
} }
type expectedAll struct {
returnMB map[string][]byte
returnErr error
}
type MockStore struct { type MockStore struct {
deleteExpectations []expectedDelete deleteExpectations []expectedDelete
findExpectations []expectedFind findExpectations []expectedFind
commitExpectations []expectedCommit commitExpectations []expectedCommit
allExpectations []expectedAll
} }
// Delete implements the Store interface func (m *MockStore) ExpectDelete(token string, err error) {
func (m *MockStore) ExpectDelete(token string, returnErr error) {
m.deleteExpectations = append(m.deleteExpectations, expectedDelete{ m.deleteExpectations = append(m.deleteExpectations, expectedDelete{
inputToken: token, inputToken: token,
returnErr: returnErr, returnErr: err,
}) })
} }
@ -109,7 +114,7 @@ func (m *MockStore) Commit(token string, b []byte, expiry time.Time) (err error)
expectationFound bool expectationFound bool
) )
for i, expectation := range m.commitExpectations { for i, expectation := range m.commitExpectations {
if expectation.inputToken == token && bytes.Compare(expectation.inputB, b) == 0 && expectation.inputExpiry == expiry { if expectation.inputToken == token && bytes.Equal(expectation.inputB, b) && expectation.inputExpiry == expiry {
indexToRemove = i indexToRemove = i
expectationFound = true expectationFound = true
break break
@ -124,3 +129,33 @@ func (m *MockStore) Commit(token string, b []byte, expiry time.Time) (err error)
return errToReturn return errToReturn
} }
func (m *MockStore) ExpectAll(mb map[string][]byte, err error) {
m.allExpectations = append(m.allExpectations, expectedAll{
returnMB: mb,
returnErr: err,
})
}
// All implements the IterableStore interface
func (m *MockStore) All() (map[string][]byte, error) {
var (
indexToRemove int
expectationFound bool
)
for i, expectation := range m.allExpectations {
if len(expectation.returnMB) == 3 {
indexToRemove = i
expectationFound = true
break
}
}
if !expectationFound {
panic("store.All called unexpectedly")
}
valueToReturn := m.allExpectations[indexToRemove]
m.allExpectations = m.allExpectations[:indexToRemove+copy(m.allExpectations[indexToRemove:], m.allExpectations[indexToRemove+1:])]
return valueToReturn.returnMB, valueToReturn.returnErr
}

View File

@ -3,6 +3,7 @@ package mockstore
import ( import (
"bytes" "bytes"
"errors" "errors"
"reflect"
"testing" "testing"
"time" "time"
) )
@ -56,10 +57,10 @@ func TestMockStore_Find(T *testing.T) {
s.ExpectFind(exampleToken, expectedBytes, expectedFound, nil) s.ExpectFind(exampleToken, expectedBytes, expectedFound, nil)
actualBytes, actualFound, actualErr := s.Find(exampleToken) actualBytes, actualFound, actualErr := s.Find(exampleToken)
if !bytes.Equal(expectedBytes, actualBytes) { if !bytes.Equal(actualBytes, expectedBytes) {
t.Error("returned bytes do not match expectation") t.Error("returned bytes do not match expectation")
} }
if expectedFound != actualFound { if actualFound != expectedFound {
t.Error("returned found does not match expectation") t.Error("returned found does not match expectation")
} }
if actualErr != nil { if actualErr != nil {
@ -127,3 +128,41 @@ func TestMockStore_Commit(T *testing.T) {
} }
}) })
} }
func TestMockStore_All(T *testing.T) {
T.Parallel()
T.Run("obligatory", func(t *testing.T) {
s := &MockStore{}
expectedMapBytes := map[string][]byte{"token1": []byte("hello, world 1!"), "token2": []byte("hello, world 2!"), "token3": []byte("hello, world 3!")}
s.ExpectAll(expectedMapBytes, nil)
actualMapBytes, actualErr := s.All()
if !reflect.DeepEqual(actualMapBytes, expectedMapBytes) {
t.Error("returned map bytes do not match expectation")
}
if actualErr != nil {
t.Error("unexpected error returned")
}
if len(s.allExpectations) != 0 {
t.Error("expectations left over after exhausting calls")
}
})
T.Run("panics with unfound expectation", func(t *testing.T) {
s := &MockStore{}
defer func() {
if r := recover(); r == nil {
t.Error("expected panic to occur")
}
}()
_, actualErr := s.All()
if actualErr != nil {
t.Error("unexpected error returned")
}
})
}