package cookie

import (
	"errors"
	"fmt"
	"net/http"
	"regexp"
	"time"

	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
	pkgcookies "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)

const (
	// Cookies are limited to 4kb for all parts, including the cookie name,
	// value and attributes, i.e. the output of (*http.Cookie).String().
	// Most browsers' max is 4096, but we give ourselves some leeway.
	maxCookieLength = 4000
)

// Compile-time check that *SessionStore satisfies sessions.SessionStore.
var _ sessions.SessionStore = (*SessionStore)(nil)

// SessionStore is an implementation of the sessions.SessionStore interface
// that keeps the session in client-side cookies.
type SessionStore struct {
	// Cookie holds the cookie options; its Name, Secret and Expire are used
	// to name, sign and validate the session cookie.
	Cookie *options.Cookie
	// CookieCipher is passed to the session encode/decode calls.
	CookieCipher encryption.Cipher
	// Minimal, when true, strips the access, ID and refresh tokens from the
	// session before it is encoded into the cookie.
	Minimal bool
}

// Save serializes ss and writes it to rw as one or more cookies.
// If ss has no CreatedAt timestamp (nil or zero) it is first set to the
// current time; note that this mutates ss.
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
	if created := ss.CreatedAt; created == nil || created.IsZero() {
		stamp := time.Now()
		ss.CreatedAt = &stamp
	}

	encoded, err := s.cookieForSession(ss)
	if err != nil {
		return err
	}
	return s.setSessionCookie(rw, req, encoded, *ss.CreatedAt)
}

// Load reads the session cookie from req, reassembling it when it was split
// across several cookies. It validates the cookie with encryption.Validate,
// using the cookie secret and expiry, and decodes the validated value into a
// sessions.SessionState.
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
	name := s.Cookie.Name

	cookie, err := loadCookie(req, name)
	if err != nil {
		// loadCookie fails only when neither the plain cookie nor any split
		// part is present, so its error is reported as "not present".
		return nil, fmt.Errorf("cookie %q not present", name)
	}

	payload, _, valid := encryption.Validate(cookie, s.Cookie.Secret, s.Cookie.Expire)
	if !valid {
		return nil, errors.New("cookie signature not valid")
	}

	session, err := sessions.DecodeSessionState(payload, s.CookieCipher, true)
	if err != nil {
		return nil, err
	}
	return session, nil
}
// clear the session func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { // matches CookieName, CookieName_ var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.Cookie.Name)) for _, c := range req.Cookies() { if cookieNameRegex.MatchString(c.Name) { clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) http.SetCookie(rw, clearCookie) } } return nil } // cookieForSession serializes a session state for storage in a cookie func (s *SessionStore) cookieForSession(ss *sessions.SessionState) ([]byte, error) { if s.Minimal && (ss.AccessToken != "" || ss.IDToken != "" || ss.RefreshToken != "") { minimal := *ss minimal.AccessToken = "" minimal.IDToken = "" minimal.RefreshToken = "" return minimal.EncodeSessionState(s.CookieCipher, true) } return ss.EncodeSessionState(s.CookieCipher, true) } // setSessionCookie adds the user's session cookie to the response func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val []byte, created time.Time) error { cookies, err := s.makeSessionCookie(req, val, created) if err != nil { return err } for _, c := range cookies { http.SetCookie(rw, c) } return nil } // makeSessionCookie creates an http.Cookie containing the authenticated user's // authentication details func (s *SessionStore) makeSessionCookie(req *http.Request, value []byte, now time.Time) ([]*http.Cookie, error) { strValue := string(value) if strValue != "" { var err error strValue, err = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, value, now) if err != nil { return nil, err } } c := s.makeCookie(req, s.Cookie.Name, strValue, s.Cookie.Expire, now) if len(c.String()) > maxCookieLength { return splitCookie(c), nil } return []*http.Cookie{c}, nil } func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { return pkgcookies.MakeCookieFromOptions( req, name, value, s.Cookie, expiration, now, ) } // 
NewCookieSessionStore initialises a new instance of the SessionStore from // the configuration given func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { cipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret)) if err != nil { return nil, fmt.Errorf("error initialising cipher: %v", err) } return &SessionStore{ CookieCipher: cipher, Cookie: cookieOpts, Minimal: opts.Cookie.Minimal, }, nil } // splitCookie reads the full cookie generated to store the session and splits // it into a slice of cookies which fit within the 4kb cookie limit indexing // the cookies from 0 func splitCookie(c *http.Cookie) []*http.Cookie { if len(c.String()) < maxCookieLength { return []*http.Cookie{c} } logger.Errorf("WARNING: Multiple cookies are required for this session as it exceeds the 4kb cookie limit. Please use server side session storage (eg. Redis) instead.") cookies := []*http.Cookie{} valueBytes := []byte(c.Value) count := 0 for len(valueBytes) > 0 { newCookie := copyCookie(c) newCookie.Name = splitCookieName(c.Name, count) count++ newCookie.Value = string(valueBytes) cookieLength := len(newCookie.String()) if cookieLength <= maxCookieLength { valueBytes = []byte{} } else { overflow := cookieLength - maxCookieLength valueSize := len(valueBytes) - overflow newValue := valueBytes[:valueSize] valueBytes = valueBytes[valueSize:] newCookie.Value = string(newValue) } cookies = append(cookies, newCookie) } return cookies } func splitCookieName(name string, count int) string { splitName := fmt.Sprintf("%s_%d", name, count) overflow := len(splitName) - 256 if overflow > 0 { splitName = fmt.Sprintf("%s_%d", name[:len(name)-overflow], count) } return splitName } // loadCookie retreieves the sessions state cookie from the http request. 
// If a single cookie is present this will be returned, otherwise it attempts // to reconstruct a cookie split up by splitCookie func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { c, err := req.Cookie(cookieName) if err == nil { return c, nil } cookies := []*http.Cookie{} err = nil count := 0 for err == nil { var c *http.Cookie c, err = req.Cookie(splitCookieName(cookieName, count)) if err == nil { cookies = append(cookies, c) count++ } } if len(cookies) == 0 { return nil, fmt.Errorf("could not find cookie %s", cookieName) } return joinCookies(cookies, cookieName) } // joinCookies takes a slice of cookies from the request and reconstructs the // full session cookie func joinCookies(cookies []*http.Cookie, cookieName string) (*http.Cookie, error) { if len(cookies) == 0 { return nil, fmt.Errorf("list of cookies must be > 0") } if len(cookies) == 1 { return cookies[0], nil } c := copyCookie(cookies[0]) for i := 1; i < len(cookies); i++ { c.Value += cookies[i].Value } c.Name = cookieName return c, nil } func copyCookie(c *http.Cookie) *http.Cookie { return &http.Cookie{ Name: c.Name, Value: c.Value, Path: c.Path, Domain: c.Domain, Expires: c.Expires, RawExpires: c.RawExpires, MaxAge: c.MaxAge, Secure: c.Secure, HttpOnly: c.HttpOnly, Raw: c.Raw, Unparsed: c.Unparsed, SameSite: c.SameSite, } }