1
0
mirror of https://github.com/alexedwards/scs.git synced 2025-07-11 00:50:14 +02:00

Add support for AllCtx() in Iterate()

This commit is contained in:
Alex Edwards
2021-11-26 13:00:03 +01:00
parent 3166777ffc
commit dba928e4fe
3 changed files with 24 additions and 14 deletions

25
data.go
View File

@ -497,20 +497,13 @@ func (s *SessionManager) RememberMe(ctx context.Context, val bool) {
// Iterate retrieves all active (i.e. not expired) sessions from the store and
// executes the provided function fn for each session. If the session store
// being used does not support iteration then Iterate will panic.
func (s *SessionManager) Iterate(fn func(context.Context) error) error {
iterableStore, ok := s.Store.(IterableStore)
if !ok {
panic(fmt.Sprintf("type %T does not implement IterableStore interface", s.Store))
}
allSessions, err := iterableStore.All()
func (s *SessionManager) Iterate(ctx context.Context, fn func(context.Context) error) error {
allSessions, err := s.doStoreAll(ctx)
if err != nil {
return err
}
for token, b := range allSessions {
ctx := context.Background()
sd := &sessionData{
status: Unmodified,
token: token,
@ -620,3 +613,17 @@ func (s *SessionManager) doStoreCommit(ctx context.Context, token string, b []by
}
return s.Store.Commit(token, b, expiry)
}
func (s *SessionManager) doStoreAll(ctx context.Context) (map[string][]byte, error) {
cs, ok := s.Store.(CtxStore)
if ok {
return cs.AllCtx(ctx)
}
is, ok := s.Store.(IterableStore)
if ok {
return is.All()
}
panic(fmt.Sprintf("type %T does not support iteration", s.Store))
}

View File

@ -325,7 +325,7 @@ func TestIterate(t *testing.T) {
results := []string{}
err := sessionManager.Iterate(func(ctx context.Context) error {
err := sessionManager.Iterate(context.Background(), func(ctx context.Context) error {
i := sessionManager.GetString(ctx, "foo")
results = append(results, i)
return nil
@ -341,7 +341,7 @@ func TestIterate(t *testing.T) {
t.Fatalf("unexpected value: got %v", results)
}
err = sessionManager.Iterate(func(ctx context.Context) error {
err = sessionManager.Iterate(context.Background(), func(ctx context.Context) error {
return errors.New("expected error")
})
if err.Error() != "expected error" {

View File

@ -39,12 +39,15 @@ type IterableStore interface {
type CtxStore interface {
Store
// DeleteCtx same as Store.Delete, excepts takes context.Context
// DeleteCtx same as Store.Delete, except it takes a context.Context.
DeleteCtx(ctx context.Context, token string) (err error)
// FindCtx same as Store.Find, excepts takes context.Context
// FindCtx same as Store.Find, except it takes a context.Context.
FindCtx(ctx context.Context, token string) (b []byte, found bool, err error)
// CommitCtx same as Store.Commit, excepts takes context.Context
// CommitCtx same as Store.Commit, except it takes a context.Context.
CommitCtx(ctx context.Context, token string, b []byte, expiry time.Time) (err error)
// AllCtx same as IterableStore.All, expect it takes a context.Context.
AllCtx(ctx context.Context) (map[string][]byte, error)
}