diff --git a/oauthproxy.go b/oauthproxy.go index e64ffe91..653701d9 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -357,22 +357,24 @@ func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessio if err != nil { return nil, err } + return s, nil +} +func (p *OAuthProxy) enrichSession(ctx context.Context, s *sessionsapi.SessionState) error { + var err error if s.Email == "" { s.Email, err = p.provider.GetEmailAddress(ctx, s) if err != nil && err.Error() != "not implemented" { - return nil, err + return err } } - if s.User == "" { s.User, err = p.provider.GetUserName(ctx, s) if err != nil && err.Error() != "not implemented" { - return nil, err + return err } } - - return s, nil + return nil } // MakeCSRFCookie creates a cookie for CSRF @@ -829,14 +831,21 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - s := strings.SplitN(req.Form.Get("state"), ":", 2) - if len(s) != 2 { + err = p.enrichSession(req.Context(), session) + if err != nil { + logger.Errorf("Error creating session during OAuth2 callback: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") + return + } + + state := strings.SplitN(req.Form.Get("state"), ":", 2) + if len(state) != 2 { logger.Error("Error while parsing OAuth2 state: invalid length") p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State") return } - nonce := s[0] - redirect := s[1] + nonce := state[0] + redirect := state[1] c, err := req.Cookie(p.CSRFCookieName) if err != nil { logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie")