You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-12-01 22:51:45 +02:00
- Save both user and email in session state:
Encoding/decoding methods save both the email and user
fields in session state, for use cases where User is not derived from
the email's local-part, as with the GitHub provider.
For backward compatibility, if the provider returns no user
(i.e. User is an empty string), the encoding/decoding methods fall back
to the previous behavior and use the email's local-part.
Updated also related tests and added two more tests to show behavior
when session contains a non-empty user value.
- Added first basic GitHub provider tests
- Added GetUserName method to Provider interface
The new GetUserName method is intended to return the User
value when this is not the email's local-part.
Added also the default implementation to provider_default.go
- Added call to GetUserName in redeemCode
the new GetUserName method is used in redeemCode
to get SessionState User value.
For backward compatibility, if GetUserName error is
"not implemented", the error is ignored.
- Added GetUserName method and tests to github provider.
280 lines
6.5 KiB
Go
280 lines
6.5 KiB
Go
package providers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
)
|
|
|
|
type GitHubProvider struct {
|
|
*ProviderData
|
|
Org string
|
|
Team string
|
|
}
|
|
|
|
func NewGitHubProvider(p *ProviderData) *GitHubProvider {
|
|
p.ProviderName = "GitHub"
|
|
if p.LoginURL == nil || p.LoginURL.String() == "" {
|
|
p.LoginURL = &url.URL{
|
|
Scheme: "https",
|
|
Host: "github.com",
|
|
Path: "/login/oauth/authorize",
|
|
}
|
|
}
|
|
if p.RedeemURL == nil || p.RedeemURL.String() == "" {
|
|
p.RedeemURL = &url.URL{
|
|
Scheme: "https",
|
|
Host: "github.com",
|
|
Path: "/login/oauth/access_token",
|
|
}
|
|
}
|
|
// ValidationURL is the API Base URL
|
|
if p.ValidateURL == nil || p.ValidateURL.String() == "" {
|
|
p.ValidateURL = &url.URL{
|
|
Scheme: "https",
|
|
Host: "api.github.com",
|
|
Path: "/",
|
|
}
|
|
}
|
|
if p.Scope == "" {
|
|
p.Scope = "user:email"
|
|
}
|
|
return &GitHubProvider{ProviderData: p}
|
|
}
|
|
func (p *GitHubProvider) SetOrgTeam(org, team string) {
|
|
p.Org = org
|
|
p.Team = team
|
|
if org != "" || team != "" {
|
|
p.Scope += " read:org"
|
|
}
|
|
}
|
|
|
|
func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
|
|
// https://developer.github.com/v3/orgs/#list-your-organizations
|
|
|
|
var orgs []struct {
|
|
Login string `json:"login"`
|
|
}
|
|
|
|
params := url.Values{
|
|
"limit": {"100"},
|
|
}
|
|
|
|
endpoint := &url.URL{
|
|
Scheme: p.ValidateURL.Scheme,
|
|
Host: p.ValidateURL.Host,
|
|
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
|
|
RawQuery: params.Encode(),
|
|
}
|
|
req, _ := http.NewRequest("GET", endpoint.String(), nil)
|
|
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken))
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
return false, fmt.Errorf(
|
|
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &orgs); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
var presentOrgs []string
|
|
for _, org := range orgs {
|
|
if p.Org == org.Login {
|
|
log.Printf("Found Github Organization: %q", org.Login)
|
|
return true, nil
|
|
}
|
|
presentOrgs = append(presentOrgs, org.Login)
|
|
}
|
|
|
|
log.Printf("Missing Organization:%q in %v", p.Org, presentOrgs)
|
|
return false, nil
|
|
}
|
|
|
|
func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
|
// https://developer.github.com/v3/orgs/teams/#list-user-teams
|
|
|
|
var teams []struct {
|
|
Name string `json:"name"`
|
|
Slug string `json:"slug"`
|
|
Org struct {
|
|
Login string `json:"login"`
|
|
} `json:"organization"`
|
|
}
|
|
|
|
params := url.Values{
|
|
"limit": {"100"},
|
|
}
|
|
|
|
endpoint := &url.URL{
|
|
Scheme: p.ValidateURL.Scheme,
|
|
Host: p.ValidateURL.Host,
|
|
Path: path.Join(p.ValidateURL.Path, "/user/teams"),
|
|
RawQuery: params.Encode(),
|
|
}
|
|
req, _ := http.NewRequest("GET", endpoint.String(), nil)
|
|
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken))
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
return false, fmt.Errorf(
|
|
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &teams); err != nil {
|
|
return false, fmt.Errorf("%s unmarshaling %s", err, body)
|
|
}
|
|
|
|
var hasOrg bool
|
|
presentOrgs := make(map[string]bool)
|
|
var presentTeams []string
|
|
for _, team := range teams {
|
|
presentOrgs[team.Org.Login] = true
|
|
if p.Org == team.Org.Login {
|
|
hasOrg = true
|
|
ts := strings.Split(p.Team, ",")
|
|
for _, t := range ts {
|
|
if t == team.Slug {
|
|
log.Printf("Found Github Organization:%q Team:%q (Name:%q)", team.Org.Login, team.Slug, team.Name)
|
|
return true, nil
|
|
}
|
|
}
|
|
presentTeams = append(presentTeams, team.Slug)
|
|
}
|
|
}
|
|
if hasOrg {
|
|
log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams)
|
|
} else {
|
|
var allOrgs []string
|
|
for org, _ := range presentOrgs {
|
|
allOrgs = append(allOrgs, org)
|
|
}
|
|
log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs)
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|
|
|
var emails []struct {
|
|
Email string `json:"email"`
|
|
Primary bool `json:"primary"`
|
|
}
|
|
|
|
// if we require an Org or Team, check that first
|
|
if p.Org != "" {
|
|
if p.Team != "" {
|
|
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
|
|
return "", err
|
|
}
|
|
} else {
|
|
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
|
|
return "", err
|
|
}
|
|
}
|
|
}
|
|
|
|
endpoint := &url.URL{
|
|
Scheme: p.ValidateURL.Scheme,
|
|
Host: p.ValidateURL.Host,
|
|
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
|
|
}
|
|
req, _ := http.NewRequest("GET", endpoint.String(), nil)
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
return "", fmt.Errorf("got %d from %q %s",
|
|
resp.StatusCode, endpoint.String(), body)
|
|
}
|
|
|
|
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
|
|
if err := json.Unmarshal(body, &emails); err != nil {
|
|
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
|
}
|
|
|
|
for _, email := range emails {
|
|
if email.Primary {
|
|
return email.Email, nil
|
|
}
|
|
}
|
|
|
|
return "", nil
|
|
}
|
|
|
|
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
|
var user struct {
|
|
Login string `json:"login"`
|
|
Email string `json:"email"`
|
|
}
|
|
|
|
endpoint := &url.URL{
|
|
Scheme: p.ValidateURL.Scheme,
|
|
Host: p.ValidateURL.Host,
|
|
Path: path.Join(p.ValidateURL.Path, "/user"),
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", endpoint.String(), nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("could not create new GET request: %v", err)
|
|
}
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
defer resp.Body.Close()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
return "", fmt.Errorf("got %d from %q %s",
|
|
resp.StatusCode, endpoint.String(), body)
|
|
}
|
|
|
|
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
|
|
if err := json.Unmarshal(body, &user); err != nil {
|
|
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
|
}
|
|
|
|
return user.Login, nil
|
|
}
|