From dba928e4fe6e1c3a7f597652ef65ec836a7fe30a Mon Sep 17 00:00:00 2001 From: Alex Edwards Date: Fri, 26 Nov 2021 13:00:03 +0100 Subject: [PATCH] Add support for AllCtx() in Iterate() --- data.go | 25 ++++++++++++++++--------- session_test.go | 4 ++-- store.go | 9 ++++++--- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/data.go b/data.go index 99ccf08..c467eea 100644 --- a/data.go +++ b/data.go @@ -497,20 +497,13 @@ func (s *SessionManager) RememberMe(ctx context.Context, val bool) { // Iterate retrieves all active (i.e. not expired) sessions from the store and // executes the provided function fn for each session. If the session store // being used does not support iteration then Iterate will panic. -func (s *SessionManager) Iterate(fn func(context.Context) error) error { - iterableStore, ok := s.Store.(IterableStore) - if !ok { - panic(fmt.Sprintf("type %T does not implement IterableStore interface", s.Store)) - } - - allSessions, err := iterableStore.All() +func (s *SessionManager) Iterate(ctx context.Context, fn func(context.Context) error) error { + allSessions, err := s.doStoreAll(ctx) if err != nil { return err } for token, b := range allSessions { - ctx := context.Background() - sd := &sessionData{ status: Unmodified, token: token, @@ -620,3 +613,17 @@ func (s *SessionManager) doStoreCommit(ctx context.Context, token string, b []by } return s.Store.Commit(token, b, expiry) } + +func (s *SessionManager) doStoreAll(ctx context.Context) (map[string][]byte, error) { + cs, ok := s.Store.(CtxStore) + if ok { + return cs.AllCtx(ctx) + } + + is, ok := s.Store.(IterableStore) + if ok { + return is.All() + } + + panic(fmt.Sprintf("type %T does not support iteration", s.Store)) +} diff --git a/session_test.go b/session_test.go index 59341bd..f702cf7 100644 --- a/session_test.go +++ b/session_test.go @@ -325,7 +325,7 @@ func TestIterate(t *testing.T) { results := []string{} - err := sessionManager.Iterate(func(ctx context.Context) error { + err := sessionManager.Iterate(context.Background(), func(ctx context.Context) error { i := sessionManager.GetString(ctx, "foo") results = append(results, i) return nil @@ -341,7 +341,7 @@ func TestIterate(t *testing.T) { t.Fatalf("unexpected value: got %v", results) } - err = sessionManager.Iterate(func(ctx context.Context) error { + err = sessionManager.Iterate(context.Background(), func(ctx context.Context) error { return errors.New("expected error") }) if err.Error() != "expected error" { diff --git a/store.go b/store.go index 00dc2bc..4fefaf9 100644 --- a/store.go +++ b/store.go @@ -39,12 +39,15 @@ type IterableStore interface { type CtxStore interface { Store - // DeleteCtx same as Store.Delete, excepts takes context.Context + // DeleteCtx same as Store.Delete, except it takes a context.Context. DeleteCtx(ctx context.Context, token string) (err error) - // FindCtx same as Store.Find, excepts takes context.Context + // FindCtx same as Store.Find, except it takes a context.Context. FindCtx(ctx context.Context, token string) (b []byte, found bool, err error) - // CommitCtx same as Store.Commit, excepts takes context.Context + // CommitCtx same as Store.Commit, except it takes a context.Context. CommitCtx(ctx context.Context, token string, b []byte, expiry time.Time) (err error) + + // AllCtx same as IterableStore.All, expect it takes a context.Context. + AllCtx(ctx context.Context) (map[string][]byte, error) }