From 10f47e325b782a60b8689653fa45360dee7fbf34 Mon Sep 17 00:00:00 2001
From: Eelco Cramer <eelco@servicelab.org>
Date: Mon, 9 Nov 2015 09:28:34 +0100
Subject: [PATCH] Add Azure Provider

---
 README.md                     |  19 +++++
 main.go                       |   2 +
 options.go                    |  19 +++--
 providers/azure.go            |  86 ++++++++++++++++++++++
 providers/azure_test.go       | 135 ++++++++++++++++++++++++++++++++++
 providers/provider_data.go    |  19 ++---
 providers/provider_default.go |   4 +
 providers/providers.go        |   2 +
 watcher.go                    |   5 +-
 9 files changed, 272 insertions(+), 19 deletions(-)
 create mode 100644 providers/azure.go
 create mode 100644 providers/azure_test.go

diff --git a/README.md b/README.md
index da71e1be..3c0f98fe 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,8 @@ You will need to register an OAuth application with a Provider (Google, Github o
 Valid providers are :
 
 * [Google](#google-auth-provider) *default*
+
+* [Azure](#azure-auth-provider)
 * [GitHub](#github-auth-provider)
 * [LinkedIn](#linkedin-auth-provider)
 * [MyUSA](#myusa-auth-provider)
@@ -76,6 +78,15 @@ and the user will be checked against all the provided groups.
 
 Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ).
 
+### Azure Auth Provider
+
+1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant.
+2. On the App properties page provide the correct Sign-On URL ie `https//internal.yourcompany.com/oauth2/callback`
+3. If applicable take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` commandline option. Default the `common` tenant is used.
+
+The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in.
+
+
 ### GitHub Auth Provider
 
 1. Create a new project: https://github.com/settings/developers
@@ -102,6 +113,12 @@ For LinkedIn, the registration steps are:
 
 The [MyUSA](https://alpha.my.usa.gov) authentication service ([GitHub](https://github.com/18F/myusa))
 
+### Microsoft Azure AD Provider
+
+For adding an application to the Microsoft Azure AD follow [these steps to add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/).
+
+Take note of your `TenantId` if applicable for your situation. The `TenantId` can be used to override the default `common` authorization server with a tenant specific server.
+
 ## Email Authentication
 
 To authorize by email domain use `--email-domain=yourcompany.com`. To authorize individual email addresses use `--authenticated-emails-file=/path/to/file` with one email per line. To authorize all email addresse use `--email-domain=*`.
@@ -120,6 +137,7 @@ An example [oauth2_proxy.cfg](contrib/oauth2_proxy.cfg.example) config file is i
 Usage of oauth2_proxy:
   -approval-prompt="force": Oauth approval_prompt
   -authenticated-emails-file="": authenticate against emails via file (one per line)
+  -azure-tenant="common": go to a tenant-specific or common (tenant-independent) endpoint.
   -basic-auth-password="": the password to set when passing the HTTP Basic Auth header
   -client-id="": the OAuth Client ID: ie: "123456.apps.googleusercontent.com"
   -client-secret="": the OAuth Client Secret
@@ -151,6 +169,7 @@ Usage of oauth2_proxy:
   -proxy-prefix="/oauth2": the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)
   -redeem-url="": Token redemption endpoint
   -redirect-url="": the OAuth Redirect URL. ie: "https://internalapp.yourcompany.com/oauth2/callback"
+  -resource="": the resource that is being protected. ie: "https://graph.windows.net". Currently only used in the Azure provider.
   -request-logging=true: Log requests to stdout
   -scope="": Oauth scope specification
   -signature-key="": GAP-Signature request signature key (algorithm:secretkey)
diff --git a/main.go b/main.go
index a8d3f1b6..dd9a100e 100644
--- a/main.go
+++ b/main.go
@@ -38,6 +38,7 @@ func main() {
 	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
 
 	flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
+	flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
 	flagSet.String("github-org", "", "restrict logins to members of this organisation")
 	flagSet.String("github-team", "", "restrict logins to members of this team")
 	flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).")
@@ -65,6 +66,7 @@ func main() {
 	flagSet.String("login-url", "", "Authentication endpoint")
 	flagSet.String("redeem-url", "", "Token redemption endpoint")
 	flagSet.String("profile-url", "", "Profile access endpoint")
+	flagSet.String("resource", "", "The resource that is protected (Azure AD only)")
 	flagSet.String("validate-url", "", "Access token validation endpoint")
 	flagSet.String("scope", "", "OAuth scope specification")
 	flagSet.String("approval-prompt", "force", "OAuth approval_prompt")
diff --git a/options.go b/options.go
index b64396cb..5d4c86f0 100644
--- a/options.go
+++ b/options.go
@@ -25,6 +25,7 @@ type Options struct {
 	TLSKeyFile   string `flag:"tls-key" cfg:"tls_key_file"`
 
 	AuthenticatedEmailsFile  string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
+	AzureTenant              string   `flag:"azure-tenant" cfg:"azure_tenant"`
 	EmailDomains             []string `flag:"email-domain" cfg:"email_domains"`
 	GitHubOrg                string   `flag:"github-org" cfg:"github_org"`
 	GitHubTeam               string   `flag:"github-team" cfg:"github_team"`
@@ -52,13 +53,14 @@ type Options struct {
 
 	// These options allow for other providers besides Google, with
 	// potential overrides.
-	Provider       string `flag:"provider" cfg:"provider"`
-	LoginURL       string `flag:"login-url" cfg:"login_url"`
-	RedeemURL      string `flag:"redeem-url" cfg:"redeem_url"`
-	ProfileURL     string `flag:"profile-url" cfg:"profile_url"`
-	ValidateURL    string `flag:"validate-url" cfg:"validate_url"`
-	Scope          string `flag:"scope" cfg:"scope"`
-	ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"`
+	Provider          string `flag:"provider" cfg:"provider"`
+	LoginURL          string `flag:"login-url" cfg:"login_url"`
+	RedeemURL         string `flag:"redeem-url" cfg:"redeem_url"`
+	ProfileURL        string `flag:"profile-url" cfg:"profile_url"`
+	ProtectedResource string `flag:"resource" cfg:"resource"`
+	ValidateURL       string `flag:"validate-url" cfg:"validate_url"`
+	Scope             string `flag:"scope" cfg:"scope"`
+	ApprovalPrompt    string `flag:"approval-prompt" cfg:"approval_prompt"`
 
 	RequestLogging bool `flag:"request-logging" cfg:"request_logging"`
 
@@ -205,9 +207,12 @@ func parseProviderInfo(o *Options, msgs []string) []string {
 	p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs)
 	p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs)
 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
+	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
 
 	o.provider = providers.New(o.Provider, p)
 	switch p := o.provider.(type) {
+	case *providers.AzureProvider:
+		p.Configure(o.AzureTenant)
 	case *providers.GitHubProvider:
 		p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam)
 	case *providers.GoogleProvider:
diff --git a/providers/azure.go b/providers/azure.go
new file mode 100644
index 00000000..2e8c57d8
--- /dev/null
+++ b/providers/azure.go
@@ -0,0 +1,86 @@
+package providers
+
+import (
+	"errors"
+	"fmt"
+	"github.com/bitly/oauth2_proxy/api"
+	"log"
+	"net/http"
+	"net/url"
+)
+
+type AzureProvider struct {
+	*ProviderData
+	Tenant string
+}
+
+func NewAzureProvider(p *ProviderData) *AzureProvider {
+	p.ProviderName = "Azure"
+
+	if p.ProfileURL == nil || p.ProfileURL.String() == "" {
+		p.ProfileURL = &url.URL{
+			Scheme:   "https",
+			Host:     "graph.windows.net",
+			Path:     "/me",
+			RawQuery: "api-version=1.6",
+		}
+	}
+	if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
+		p.ProtectedResource = &url.URL{
+			Scheme: "https",
+			Host:   "graph.windows.net",
+		}
+	}
+	if p.Scope == "" {
+		p.Scope = "openid"
+	}
+
+	return &AzureProvider{ProviderData: p}
+}
+
+func (p *AzureProvider) Configure(tenant string) {
+	p.Tenant = tenant
+	if tenant == "" {
+		p.Tenant = "common"
+	}
+
+	if p.LoginURL == nil || p.LoginURL.String() == "" {
+		p.LoginURL = &url.URL{
+			Scheme: "https",
+			Host:   "login.microsoftonline.com",
+			Path:   "/" + p.Tenant + "/oauth2/authorize"}
+	}
+	if p.RedeemURL == nil || p.RedeemURL.String() == "" {
+		p.RedeemURL = &url.URL{
+			Scheme: "https",
+			Host:   "login.microsoftonline.com",
+			Path:   "/" + p.Tenant + "/oauth2/token",
+		}
+	}
+}
+
+func getAzureHeader(access_token string) http.Header {
+	header := make(http.Header)
+	header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
+	return header
+}
+
+func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
+	if s.AccessToken == "" {
+		return "", errors.New("missing access token")
+	}
+	req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
+	if err != nil {
+		return "", err
+	}
+	req.Header = getAzureHeader(s.AccessToken)
+
+	json, err := api.Request(req)
+
+	if err != nil {
+		log.Printf("failed making request %s", err)
+		return "", err
+	}
+
+	return json.Get("mail").String()
+}
diff --git a/providers/azure_test.go b/providers/azure_test.go
new file mode 100644
index 00000000..1aa823ac
--- /dev/null
+++ b/providers/azure_test.go
@@ -0,0 +1,135 @@
+package providers
+
+import (
+	"github.com/bmizerany/assert"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+)
+
+func testAzureProvider(hostname string) *AzureProvider {
+	p := NewAzureProvider(
+		&ProviderData{
+			ProviderName:      "",
+			LoginURL:          &url.URL{},
+			RedeemURL:         &url.URL{},
+			ProfileURL:        &url.URL{},
+			ValidateURL:       &url.URL{},
+			ProtectedResource: &url.URL{},
+			Scope:             ""})
+	if hostname != "" {
+		updateURL(p.Data().LoginURL, hostname)
+		updateURL(p.Data().RedeemURL, hostname)
+		updateURL(p.Data().ProfileURL, hostname)
+		updateURL(p.Data().ValidateURL, hostname)
+		updateURL(p.Data().ProtectedResource, hostname)
+	}
+	return p
+}
+
+func TestAzureProviderDefaults(t *testing.T) {
+	p := testAzureProvider("")
+	assert.NotEqual(t, nil, p)
+	p.Configure("")
+	assert.Equal(t, "Azure", p.Data().ProviderName)
+	assert.Equal(t, "common", p.Tenant)
+	assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/authorize",
+		p.Data().LoginURL.String())
+	assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/token",
+		p.Data().RedeemURL.String())
+	assert.Equal(t, "https://graph.windows.net/me?api-version=1.6",
+		p.Data().ProfileURL.String())
+	assert.Equal(t, "https://graph.windows.net",
+		p.Data().ProtectedResource.String())
+	assert.Equal(t, "",
+		p.Data().ValidateURL.String())
+	assert.Equal(t, "openid", p.Data().Scope)
+}
+
+func TestAzureProviderOverrides(t *testing.T) {
+	p := NewAzureProvider(
+		&ProviderData{
+			LoginURL: &url.URL{
+				Scheme: "https",
+				Host:   "example.com",
+				Path:   "/oauth/auth"},
+			RedeemURL: &url.URL{
+				Scheme: "https",
+				Host:   "example.com",
+				Path:   "/oauth/token"},
+			ProfileURL: &url.URL{
+				Scheme: "https",
+				Host:   "example.com",
+				Path:   "/oauth/profile"},
+			ValidateURL: &url.URL{
+				Scheme: "https",
+				Host:   "example.com",
+				Path:   "/oauth/tokeninfo"},
+			ProtectedResource: &url.URL{
+				Scheme: "https",
+				Host:   "example.com"},
+			Scope: "profile"})
+	assert.NotEqual(t, nil, p)
+	assert.Equal(t, "Azure", p.Data().ProviderName)
+	assert.Equal(t, "https://example.com/oauth/auth",
+		p.Data().LoginURL.String())
+	assert.Equal(t, "https://example.com/oauth/token",
+		p.Data().RedeemURL.String())
+	assert.Equal(t, "https://example.com/oauth/profile",
+		p.Data().ProfileURL.String())
+	assert.Equal(t, "https://example.com/oauth/tokeninfo",
+		p.Data().ValidateURL.String())
+	assert.Equal(t, "https://example.com",
+		p.Data().ProtectedResource.String())
+	assert.Equal(t, "profile", p.Data().Scope)
+}
+
+func TestAzureSetTenant(t *testing.T) {
+	p := testAzureProvider("")
+	p.Configure("example")
+	assert.Equal(t, "Azure", p.Data().ProviderName)
+	assert.Equal(t, "example", p.Tenant)
+	assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/authorize",
+		p.Data().LoginURL.String())
+	assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/token",
+		p.Data().RedeemURL.String())
+	assert.Equal(t, "https://graph.windows.net/me?api-version=1.6",
+		p.Data().ProfileURL.String())
+	assert.Equal(t, "https://graph.windows.net",
+		p.Data().ProtectedResource.String())
+	assert.Equal(t, "",
+		p.Data().ValidateURL.String())
+	assert.Equal(t, "openid", p.Data().Scope)
+}
+
+func testAzureBackend(payload string) *httptest.Server {
+	path := "/me"
+	query := "api-version=1.6"
+
+	return httptest.NewServer(http.HandlerFunc(
+		func(w http.ResponseWriter, r *http.Request) {
+			url := r.URL
+			if url.Path != path || url.RawQuery != query {
+				w.WriteHeader(404)
+			} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
+				w.WriteHeader(403)
+			} else {
+				w.WriteHeader(200)
+				w.Write([]byte(payload))
+			}
+		}))
+}
+
+func TestAzureProviderGetEmailAddress(t *testing.T) {
+	b := testAzureBackend(`{ "mail": "user@windows.net" }`)
+	defer b.Close()
+
+	b_url, _ := url.Parse(b.URL)
+	p := testAzureProvider(b_url.Host)
+
+	session := &SessionState{AccessToken: "imaginary_access_token"}
+	email, err := p.GetEmailAddress(session)
+	assert.Equal(t, nil, err)
+	assert.Equal(t, "user@windows.net", email)
+}
diff --git a/providers/provider_data.go b/providers/provider_data.go
index a13ed8e5..92e27dd7 100644
--- a/providers/provider_data.go
+++ b/providers/provider_data.go
@@ -5,15 +5,16 @@ import (
 )
 
 type ProviderData struct {
-	ProviderName   string
-	ClientID       string
-	ClientSecret   string
-	LoginURL       *url.URL
-	RedeemURL      *url.URL
-	ProfileURL     *url.URL
-	ValidateURL    *url.URL
-	Scope          string
-	ApprovalPrompt string
+	ProviderName      string
+	ClientID          string
+	ClientSecret      string
+	LoginURL          *url.URL
+	RedeemURL         *url.URL
+	ProfileURL        *url.URL
+	ProtectedResource *url.URL
+	ValidateURL       *url.URL
+	Scope             string
+	ApprovalPrompt    string
 }
 
 func (p *ProviderData) Data() *ProviderData { return p }
diff --git a/providers/provider_default.go b/providers/provider_default.go
index 77b3dfdf..82b73ec3 100644
--- a/providers/provider_default.go
+++ b/providers/provider_default.go
@@ -25,6 +25,10 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
 	params.Add("client_secret", p.ClientSecret)
 	params.Add("code", code)
 	params.Add("grant_type", "authorization_code")
+	if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
+		params.Add("resource", p.ProtectedResource.String())
+	}
+
 	var req *http.Request
 	req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
 	if err != nil {
diff --git a/providers/providers.go b/providers/providers.go
index 59e5f9a2..db0fe13b 100644
--- a/providers/providers.go
+++ b/providers/providers.go
@@ -24,6 +24,8 @@ func New(provider string, p *ProviderData) Provider {
 		return NewLinkedInProvider(p)
 	case "github":
 		return NewGitHubProvider(p)
+	case "azure":
+		return NewAzureProvider(p)
 	default:
 		return NewGoogleProvider(p)
 	}
diff --git a/watcher.go b/watcher.go
index c34058b1..bedb9f89 100644
--- a/watcher.go
+++ b/watcher.go
@@ -41,9 +41,8 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) {
 		for {
 			select {
 			case _ = <-done:
-				log.Printf("Shutting down watcher for: %s",
-					filename)
-				return
+				log.Printf("Shutting down watcher for: %s", filename)
+				break
 			case event := <-watcher.Events:
 				// On Arch Linux, it appears Chmod events precede Remove events,
 				// which causes a race between action() and the coming Remove event.