diff --git a/remote/bitbucketserver/bitbucketserver.go b/remote/bitbucketserver/bitbucketserver.go new file mode 100644 index 000000000..a931f5bb1 --- /dev/null +++ b/remote/bitbucketserver/bitbucketserver.go @@ -0,0 +1,345 @@ +package bitbucketserver + +// Requires the following to be set +// REMOTE_DRIVER=bitbucketserver +// REMOTE_CONFIG=https://{servername}?consumer_key={key added on the stash server for oath1}&git_username={username for clone}&git_password={password for clone}&consumer_rsa=/path/to/pem.file&open={not used yet} +// Configure application links in the bitbucket server -- +// application url needs to be the base url to drone +// incoming auth needs to have the consumer key (same as the key in REMOTE_CONFIG) +// set the public key (public key from the private key added to /var/lib/bitbucketserver/private_key.pem name matters) +// consumer call back is the base url to drone plus /authorize/ +// Needs a pem private key added to /var/lib/bitbucketserver/private_key.pem +// After that you should be good to go + +import ( + "encoding/json" + "fmt" + log "github.com/Sirupsen/logrus" + "github.com/drone/drone/model" + "github.com/mrjones/oauth" + "io/ioutil" + "net/http" + "net/url" + "strconv" +) + +type BitbucketServer struct { + URL string + ConsumerKey string + GitUserName string + GitPassword string + ConsumerRSA string + Open bool + Consumer oauth.Consumer +} + +func Load(config string) *BitbucketServer { + + url_, err := url.Parse(config) + if err != nil { + log.Fatalln("unable to parse remote dsn. %s", err) + } + params := url_.Query() + url_.Path = "" + url_.RawQuery = "" + + bitbucketserver := BitbucketServer{} + bitbucketserver.URL = url_.String() + bitbucketserver.GitUserName = params.Get("git_username") + if bitbucketserver.GitUserName == "" { + log.Fatalln("Must have a git_username") + } + bitbucketserver.GitPassword = params.Get("git_password") + if bitbucketserver.GitPassword == "" { + log.Fatalln("Must have a git_password") + } + bitbucketserver.ConsumerKey = params.Get("consumer_key") + if bitbucketserver.ConsumerKey == "" { + log.Fatalln("Must have a consumer_key") + } + bitbucketserver.ConsumerRSA = params.Get("consumer_rsa") + if bitbucketserver.ConsumerRSA == "" { + log.Fatalln("Must have a consumer_rsa") + } + + bitbucketserver.Open, _ = strconv.ParseBool(params.Get("open")) + + bitbucketserver.Consumer = *NewClient(bitbucketserver.ConsumerRSA, bitbucketserver.ConsumerKey, bitbucketserver.URL) + + return &bitbucketserver +} + +func (bs *BitbucketServer) Login(res http.ResponseWriter, req *http.Request) (*model.User, bool, error) { + log.Info("Starting to login for bitbucketServer") + + log.Info("getting the requestToken") + requestToken, url, err := bs.Consumer.GetRequestTokenAndUrl("oob") + if err != nil { + log.Error(err) + } + + var code = req.FormValue("oauth_verifier") + if len(code) == 0 { + log.Info("redirecting to %s", url) + http.Redirect(res, req, url, http.StatusSeeOther) + return nil, false, nil + } + + var request_oauth_token = req.FormValue("oauth_token") + requestToken.Token = request_oauth_token + accessToken, err := bs.Consumer.AuthorizeToken(requestToken, code) + if err != nil { + log.Error(err) + } + + client, err := bs.Consumer.MakeHttpClient(accessToken) + if err != nil { + log.Error(err) + } + + response, err := client.Get(fmt.Sprintf("%s/plugins/servlet/applinks/whoami", bs.URL)) + if err != nil { + log.Error(err) + } + defer response.Body.Close() + bits, err := ioutil.ReadAll(response.Body) + userName := string(bits) + + response1, err := client.Get(fmt.Sprintf("%s/rest/api/1.0/users/%s",bs.URL, userName)) + contents, err := ioutil.ReadAll(response1.Body) + defer response1.Body.Close() + var mUser User + json.Unmarshal(contents, &mUser) + + user := model.User{} + user.Login = userName + user.Email = mUser.EmailAddress + user.Token = accessToken.Token + + user.Avatar = avatarLink(mUser.EmailAddress) + + return &user, bs.Open, nil +} + +func (bs *BitbucketServer) Auth(token, secret string) (string, error) { + log.Info("Staring to auth for bitbucketServer. %s", token) + if len(token) == 0 { + return "", fmt.Errorf("Hasn't logged in yet") + } + return token, nil +} + +func (bs *BitbucketServer) Repo(u *model.User, owner, name string) (*model.Repo, error) { + log.Info("Staring repo for bitbucketServer with user " + u.Login + " " + owner + " " + name) + + client := NewClientWithToken(&bs.Consumer, u.Token) + + url := fmt.Sprintf("%s/rest/api/projects/%s/repos/%s",bs.URL,owner,name) + log.Info("Trying to get " + url) + response, err := client.Get(url) + if err != nil { + log.Error(err) + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + bsRepo := BSRepo{} + json.Unmarshal(contents, &bsRepo) + + cloneLink := "" + repoLink := "" + + for _, item := range bsRepo.Links.Clone { + if item.Name == "http" { + cloneLink = item.Href + } + } + for _, item := range bsRepo.Links.Self { + if item.Href != "" { + repoLink = item.Href + } + } + //TODO: get the real allow tag+ infomration + repo := &model.Repo{} + repo.Clone = cloneLink + repo.Link = repoLink + repo.Name = bsRepo.Slug + repo.Owner = bsRepo.Project.Key + repo.AllowPush = true + repo.FullName = fmt.Sprintf("%s/%s",bsRepo.Project.Key,bsRepo.Slug) + repo.Branch = "master" + repo.Kind = model.RepoGit + + return repo, nil +} + +func (bs *BitbucketServer) Repos(u *model.User) ([]*model.RepoLite, error) { + log.Info("Staring repos for bitbucketServer " + u.Login) + var repos = []*model.RepoLite{} + + client := NewClientWithToken(&bs.Consumer, u.Token) + + response, err := client.Get(fmt.Sprintf("%s/rest/api/1.0/repos?limit=10000",bs.URL)) + if err != nil { + log.Error(err) + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + var repoResponse Repos + json.Unmarshal(contents, &repoResponse) + + for _, repo := range repoResponse.Values { + repos = append(repos, &model.RepoLite{ + Name: repo.Slug, + FullName: repo.Project.Key + "/" + repo.Slug, + Owner: repo.Project.Key, + }) + } + + return repos, nil +} + +func (bs *BitbucketServer) Perm(u *model.User, owner, repo string) (*model.Perm, error) { + + //TODO: find the real permissions + log.Info("Staring perm for bitbucketServer") + perms := new(model.Perm) + perms.Pull = true + perms.Admin = true + perms.Push = true + return perms, nil +} + +func (bs *BitbucketServer) File(u *model.User, r *model.Repo, b *model.Build, f string) ([]byte, error) { + log.Info(fmt.Sprintf("Staring file for bitbucketServer login: %s repo: %s buildevent: %s string: %s", u.Login, r.Name, b.Event, f)) + + client := NewClientWithToken(&bs.Consumer, u.Token) + fileURL := fmt.Sprintf("%s/projects/%s/repos/%s/browse/%s?raw", bs.URL, r.Owner, r.Name, f) + log.Info(fileURL) + response, err := client.Get(fileURL) + if err != nil { + log.Error(err) + } + if response.StatusCode == 404 { + return nil, nil + } + defer response.Body.Close() + responseBytes, err := ioutil.ReadAll(response.Body) + if err != nil { + log.Error(err) + } + + return responseBytes, nil +} + +func (bs *BitbucketServer) Status(u *model.User, r *model.Repo, b *model.Build, link string) error { + log.Info("Staring status for bitbucketServer") + return nil +} + +func (bs *BitbucketServer) Netrc(user *model.User, r *model.Repo) (*model.Netrc, error) { + log.Info("Starting the Netrc lookup") + u, err := url.Parse(bs.URL) + if err != nil { + return nil, err + } + return &model.Netrc{ + Machine: u.Host, + Login: bs.GitUserName, + Password: bs.GitPassword, + }, nil +} + +func (bs *BitbucketServer) Activate(u *model.User, r *model.Repo, k *model.Key, link string) error { + log.Info(fmt.Sprintf("Staring activate for bitbucketServer user: %s repo: %s key: %s link: %s", u.Login, r.Name, k, link)) + client := NewClientWithToken(&bs.Consumer, u.Token) + hook, err := bs.CreateHook(client, r.Owner, r.Name, "com.atlassian.stash.plugin.stash-web-post-receive-hooks-plugin:postReceiveHook", link) + if err != nil { + return err + } + log.Info(hook) + return nil +} + +func (bs *BitbucketServer) Deactivate(u *model.User, r *model.Repo, link string) error { + log.Info(fmt.Sprintf("Staring deactivating for bitbucketServer user: %s repo: %s link: %s", u.Login, r.Name, link)) + client := NewClientWithToken(&bs.Consumer, u.Token) + err := bs.DeleteHook(client, r.Owner, r.Name, "com.atlassian.stash.plugin.stash-web-post-receive-hooks-plugin:postReceiveHook", link) + if err != nil { + return err + } + return nil +} + +func (bs *BitbucketServer) Hook(r *http.Request) (*model.Repo, *model.Build, error) { + log.Info("Staring hook for bitbucketServer") + defer r.Body.Close() + contents, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Info(err) + } + + var hookPost postHook + json.Unmarshal(contents, &hookPost) + + buildModel := &model.Build{} + buildModel.Event = model.EventPush + buildModel.Ref = hookPost.RefChanges[0].RefID + buildModel.Author = hookPost.Changesets.Values[0].ToCommit.Author.EmailAddress + buildModel.Commit = hookPost.RefChanges[0].ToHash + buildModel.Avatar = avatarLink(hookPost.Changesets.Values[0].ToCommit.Author.EmailAddress) + + //All you really need is the name and owner. That's what creates the lookup key, so it needs to match the repo info. Just an FYI + repo := &model.Repo{} + repo.Name = hookPost.Repository.Slug + repo.Owner = hookPost.Repository.Project.Key + repo.AllowTag = false + repo.AllowDeploy = false + repo.AllowPull = false + repo.AllowPush = true + repo.FullName = fmt.Sprintf("%s/%s",hookPost.Repository.Project.Key,hookPost.Repository.Slug) + repo.Branch = "master" + repo.Kind = model.RepoGit + + return repo, buildModel, nil +} +func (bs *BitbucketServer) String() string { + return "bitbucketserver" +} + +type HookDetail struct { + Key string `"json:key"` + Name string `"json:name"` + Type string `"json:type"` + Description string `"json:description"` + Version string `"json:version"` + ConfigFormKey string `"json:configFormKey"` +} + +type Hook struct { + Enabled bool `"json:enabled"` + Details *HookDetail `"json:details"` +} + +// Enable hook for named repository +func (bs *BitbucketServer) CreateHook(client *http.Client, project, slug, hook_key, link string) (*Hook, error) { + + // Set hook + hookBytes := []byte(fmt.Sprintf(`{"hook-url-0":"%s"}`, link)) + + // Enable hook + enablePath := fmt.Sprintf("/rest/api/1.0/projects/%s/repos/%s/settings/hooks/%s/enabled", + project, slug, hook_key) + + doPut(client, bs.URL+enablePath, hookBytes) + + return nil, nil +} + +// Disable hook for named repository +func (bs *BitbucketServer) DeleteHook(client *http.Client, project, slug, hook_key, link string) error { + enablePath := fmt.Sprintf("/rest/api/1.0/projects/%s/repos/%s/settings/hooks/%s/enabled", + project, slug, hook_key) + doDelete(client, bs.URL+enablePath) + + return nil +} diff --git a/remote/bitbucketserver/client.go b/remote/bitbucketserver/client.go new file mode 100644 index 000000000..6c6f280b6 --- /dev/null +++ b/remote/bitbucketserver/client.go @@ -0,0 +1,52 @@ +package bitbucketserver + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + log "github.com/Sirupsen/logrus" + "github.com/mrjones/oauth" + "io/ioutil" + "net/http" +) + +func NewClient(ConsumerRSA string, ConsumerKey string, URL string) *oauth.Consumer { + //TODO: make this configurable + privateKeyFileContents, err := ioutil.ReadFile(ConsumerRSA) + log.Info("Tried to read the key") + if err != nil { + log.Error(err) + } + + block, _ := pem.Decode([]byte(privateKeyFileContents)) + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + log.Error(err) + } + + c := oauth.NewRSAConsumer( + ConsumerKey, + privateKey, + oauth.ServiceProvider{ + RequestTokenUrl: URL + "/plugins/servlet/oauth/request-token", + AuthorizeTokenUrl: URL + "/plugins/servlet/oauth/authorize", + AccessTokenUrl: URL + "/plugins/servlet/oauth/access-token", + HttpMethod: "POST", + }) + c.HttpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + return c +} + +func NewClientWithToken(Consumer *oauth.Consumer, AccessToken string) *http.Client { + var token oauth.AccessToken + token.Token = AccessToken + client, err := Consumer.MakeHttpClient(&token) + if err != nil { + log.Error(err) + } + return client +} diff --git a/remote/bitbucketserver/helper.go b/remote/bitbucketserver/helper.go new file mode 100644 index 000000000..9c3985113 --- /dev/null +++ b/remote/bitbucketserver/helper.go @@ -0,0 +1,64 @@ +package bitbucketserver + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "fmt" + log "github.com/Sirupsen/logrus" + "io/ioutil" + "net/http" + "strings" +) + +func avatarLink(email string) (url string) { + hasher := md5.New() + hasher.Write([]byte(strings.ToLower(email))) + emailHash := fmt.Sprintf("%v", hex.EncodeToString(hasher.Sum(nil))) + avatarURL := fmt.Sprintf("https://www.gravatar.com/avatar/%s.jpg", emailHash) + log.Info(avatarURL) + return avatarURL +} + +func doPut(client *http.Client, url string, body []byte) { + request, err := http.NewRequest("PUT", url, bytes.NewBuffer(body)) + request.Header.Add("Content-Type", "application/json") + response, err := client.Do(request) + if err != nil { + log.Error(err) + } else { + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + log.Error(err) + } + fmt.Println("The calculated length is:", len(string(contents)), "for the url:", url) + fmt.Println(" ", response.StatusCode) + hdr := response.Header + for key, value := range hdr { + fmt.Println(" ", key, ":", value) + } + fmt.Println(string(contents)) + } +} + +func doDelete(client *http.Client, url string) { + request, err := http.NewRequest("DELETE", url, nil) + response, err := client.Do(request) + if err != nil { + log.Error(err) + } else { + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + log.Error(err) + } + fmt.Println("The calculated length is:", len(string(contents)), "for the url:", url) + fmt.Println(" ", response.StatusCode) + hdr := response.Header + for key, value := range hdr { + fmt.Println(" ", key, ":", value) + } + fmt.Println(string(contents)) + } +} diff --git a/remote/bitbucketserver/types.go b/remote/bitbucketserver/types.go new file mode 100644 index 000000000..3a1fb9083 --- /dev/null +++ b/remote/bitbucketserver/types.go @@ -0,0 +1,172 @@ +package bitbucketserver + +type postHook struct { + Changesets struct { + Filter interface{} `json:"filter"` + IsLastPage bool `json:"isLastPage"` + Limit int `json:"limit"` + Size int `json:"size"` + Start int `json:"start"` + Values []struct { + Changes struct { + Filter interface{} `json:"filter"` + IsLastPage bool `json:"isLastPage"` + Limit int `json:"limit"` + Size int `json:"size"` + Start int `json:"start"` + Values []struct { + ContentID string `json:"contentId"` + Executable bool `json:"executable"` + Link struct { + Rel string `json:"rel"` + URL string `json:"url"` + } `json:"link"` + NodeType string `json:"nodeType"` + Path struct { + Components []string `json:"components"` + Extension string `json:"extension"` + Name string `json:"name"` + Parent string `json:"parent"` + ToString string `json:"toString"` + } `json:"path"` + PercentUnchanged int `json:"percentUnchanged"` + SrcExecutable bool `json:"srcExecutable"` + Type string `json:"type"` + } `json:"values"` + } `json:"changes"` + FromCommit struct { + DisplayID string `json:"displayId"` + ID string `json:"id"` + } `json:"fromCommit"` + Link struct { + Rel string `json:"rel"` + URL string `json:"url"` + } `json:"link"` + ToCommit struct { + Author struct { + EmailAddress string `json:"emailAddress"` + Name string `json:"name"` + } `json:"author"` + AuthorTimestamp int `json:"authorTimestamp"` + DisplayID string `json:"displayId"` + ID string `json:"id"` + Message string `json:"message"` + Parents []struct { + DisplayID string `json:"displayId"` + ID string `json:"id"` + } `json:"parents"` + } `json:"toCommit"` + } `json:"values"` + } `json:"changesets"` + RefChanges []struct { + FromHash string `json:"fromHash"` + RefID string `json:"refId"` + ToHash string `json:"toHash"` + Type string `json:"type"` + } `json:"refChanges"` + Repository struct { + Forkable bool `json:"forkable"` + ID int `json:"id"` + Name string `json:"name"` + Project struct { + ID int `json:"id"` + IsPersonal bool `json:"isPersonal"` + Key string `json:"key"` + Name string `json:"name"` + Public bool `json:"public"` + Type string `json:"type"` + } `json:"project"` + Public bool `json:"public"` + ScmID string `json:"scmId"` + Slug string `json:"slug"` + State string `json:"state"` + StatusMessage string `json:"statusMessage"` + } `json:"repository"` +} + +type Repos struct { + IsLastPage bool `json:"isLastPage"` + Limit int `json:"limit"` + Size int `json:"size"` + Start int `json:"start"` + Values []struct { + Forkable bool `json:"forkable"` + ID int `json:"id"` + Links struct { + Clone []struct { + Href string `json:"href"` + Name string `json:"name"` + } `json:"clone"` + Self []struct { + Href string `json:"href"` + } `json:"self"` + } `json:"links"` + Name string `json:"name"` + Project struct { + Description string `json:"description"` + ID int `json:"id"` + Key string `json:"key"` + Links struct { + Self []struct { + Href string `json:"href"` + } `json:"self"` + } `json:"links"` + Name string `json:"name"` + Public bool `json:"public"` + Type string `json:"type"` + } `json:"project"` + Public bool `json:"public"` + ScmID string `json:"scmId"` + Slug string `json:"slug"` + State string `json:"state"` + StatusMessage string `json:"statusMessage"` + } `json:"values"` +} + +type User struct { + Active bool `json:"active"` + DisplayName string `json:"displayName"` + EmailAddress string `json:"emailAddress"` + ID int `json:"id"` + Links struct { + Self []struct { + Href string `json:"href"` + } `json:"self"` + } `json:"links"` + Name string `json:"name"` + Slug string `json:"slug"` + Type string `json:"type"` +} + +type BSRepo struct { + Forkable bool `json:"forkable"` + ID int `json:"id"` + Links struct { + Clone []struct { + Href string `json:"href"` + Name string `json:"name"` + } `json:"clone"` + Self []struct { + Href string `json:"href"` + } `json:"self"` + } `json:"links"` + Name string `json:"name"` + Project struct { + Description string `json:"description"` + ID int `json:"id"` + Key string `json:"key"` + Links struct { + Self []struct { + Href string `json:"href"` + } `json:"self"` + } `json:"links"` + Name string `json:"name"` + Public bool `json:"public"` + Type string `json:"type"` + } `json:"project"` + Public bool `json:"public"` + ScmID string `json:"scmId"` + Slug string `json:"slug"` + State string `json:"state"` + StatusMessage string `json:"statusMessage"` +} diff --git a/router/middleware/remote.go b/router/middleware/remote.go index 0c49ae361..5f245676a 100644 --- a/router/middleware/remote.go +++ b/router/middleware/remote.go @@ -10,6 +10,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/gin-gonic/gin" "github.com/ianschenck/envflag" + "github.com/drone/drone/remote/bitbucketserver" ) var ( @@ -34,6 +35,8 @@ func Remote() gin.HandlerFunc { remote_ = gogs.Load(*config) case "gitlab": remote_ = gitlab.Load(*config) + case "bitbucketserver": + remote_ = bitbucketserver.Load(*config) default: logrus.Fatalln("remote configuraiton not found") } diff --git a/vendor/github.com/mrjones/oauth/README.md b/vendor/github.com/mrjones/oauth/README.md new file mode 100644 index 000000000..6edab0918 --- /dev/null +++ b/vendor/github.com/mrjones/oauth/README.md @@ -0,0 +1,49 @@ +OAuth 1.0 Library for [Go](http://golang.org) +======================== + +[![GoDoc](http://godoc.org/github.com/mrjones/oauth?status.png)](http://godoc.org/github.com/mrjones/oauth) + +(If you need an OAuth 2.0 library, check out: http://code.google.com/p/goauth2/) + +Developing your own apps, with this library +------------------------------------------- + +* First, install the library + + go get github.com/mrjones/oauth + +* Then, check out the comments in oauth.go + +* Or, have a look at the examples: + + * Netflix + + go run examples/netflix/netflix.go --consumerkey [key] --consumersecret [secret] --appname [appname] + + * Twitter + + Command line: + + go run examples/twitter/twitter.go --consumerkey [key] --consumersecret [secret] + + Or, in the browser (using an HTTP server): + + go run examples/twitterserver/twitterserver.go --consumerkey [key] --consumersecret [secret] --port 8888 + + * The Google Latitude example is broken, now that Google uses OAuth 2.0 + +Contributing to this library +---------------------------- + +* Please install the pre-commit hook, which will run tests, and go-fmt before committing. + + ln -s $PWD/pre-commit.sh .git/hooks/pre-commit + +* Running tests and building is as you'd expect: + + go test *.go + go build *.go + + + + diff --git a/vendor/github.com/mrjones/oauth/oauth.go b/vendor/github.com/mrjones/oauth/oauth.go new file mode 100644 index 000000000..c9b310980 --- /dev/null +++ b/vendor/github.com/mrjones/oauth/oauth.go @@ -0,0 +1,1389 @@ +// OAuth 1.0 consumer implementation. +// See http://www.oauth.net and RFC 5849 +// +// There are typically three parties involved in an OAuth exchange: +// (1) The "Service Provider" (e.g. Google, Twitter, NetFlix) who operates the +// service where the data resides. +// (2) The "End User" who owns that data, and wants to grant access to a third-party. +// (3) That third-party who wants access to the data (after first being authorized by +// the user). This third-party is referred to as the "Consumer" in OAuth +// terminology. +// +// This library is designed to help implement the third-party consumer by handling the +// low-level authentication tasks, and allowing for authenticated requests to the +// service provider on behalf of the user. +// +// Caveats: +// - Currently only supports HMAC and RSA signatures. +// - Currently only supports SHA1 and SHA256 hashes. +// - Currently only supports OAuth 1.0 +// +// Overview of how to use this library: +// (1) First create a new Consumer instance with the NewConsumer function +// (2) Get a RequestToken, and "authorization url" from GetRequestTokenAndUrl() +// (3) Save the RequestToken, you will need it again in step 6. +// (4) Redirect the user to the "authorization url" from step 2, where they will +// authorize your access to the service provider. +// (5) Wait. You will be called back on the CallbackUrl that you provide, and you +// will recieve a "verification code". +// (6) Call AuthorizeToken() with the RequestToken from step 2 and the +// "verification code" from step 5. +// (7) You will get back an AccessToken. Save this for as long as you need access +// to the user's data, and treat it like a password; it is a secret. +// (8) You can now throw away the RequestToken from step 2, it is no longer +// necessary. +// (9) Call "MakeHttpClient" using the AccessToken from step 7 to get an +// HTTP client which can access protected resources. +package oauth + +import ( + "bytes" + "crypto" + "crypto/hmac" + cryptoRand "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "mime/multipart" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "sync" + "time" + log "github.com/Sirupsen/logrus" +) + +const ( + OAUTH_VERSION = "1.0" + SIGNATURE_METHOD_HMAC = "HMAC-" + SIGNATURE_METHOD_RSA = "RSA-" + + HTTP_AUTH_HEADER = "Authorization" + OAUTH_HEADER = "OAuth " + BODY_HASH_PARAM = "oauth_body_hash" + CALLBACK_PARAM = "oauth_callback" + CONSUMER_KEY_PARAM = "oauth_consumer_key" + NONCE_PARAM = "oauth_nonce" + SESSION_HANDLE_PARAM = "oauth_session_handle" + SIGNATURE_METHOD_PARAM = "oauth_signature_method" + SIGNATURE_PARAM = "oauth_signature" + TIMESTAMP_PARAM = "oauth_timestamp" + TOKEN_PARAM = "oauth_token" + TOKEN_SECRET_PARAM = "oauth_token_secret" + VERIFIER_PARAM = "oauth_verifier" + VERSION_PARAM = "oauth_version" +) + +var HASH_METHOD_MAP = map[crypto.Hash]string{ + crypto.SHA1: "SHA1", + crypto.SHA256: "SHA256", +} + +// TODO(mrjones) Do we definitely want separate "Request" and "Access" token classes? +// They're identical structurally, but used for different purposes. +type RequestToken struct { + Token string + Secret string +} + +type AccessToken struct { + Token string + Secret string + AdditionalData map[string]string +} + +type DataLocation int + +const ( + LOC_BODY DataLocation = iota + 1 + LOC_URL + LOC_MULTIPART + LOC_JSON + LOC_XML +) + +// Information about how to contact the service provider (see #1 above). +// You usually find all of these URLs by reading the documentation for the service +// that you're trying to connect to. +// Some common examples are: +// (1) Google, standard APIs: +// http://code.google.com/apis/accounts/docs/OAuth_ref.html +// - RequestTokenUrl: https://www.google.com/accounts/OAuthGetRequestToken +// - AuthorizeTokenUrl: https://www.google.com/accounts/OAuthAuthorizeToken +// - AccessTokenUrl: https://www.google.com/accounts/OAuthGetAccessToken +// Note: Some Google APIs (for example, Google Latitude) use different values for +// one or more of those URLs. +// (2) Twitter API: +// http://dev.twitter.com/pages/auth +// - RequestTokenUrl: http://api.twitter.com/oauth/request_token +// - AuthorizeTokenUrl: https://api.twitter.com/oauth/authorize +// - AccessTokenUrl: https://api.twitter.com/oauth/access_token +// (3) NetFlix API: +// http://developer.netflix.com/docs/Security +// - RequestTokenUrl: http://api.netflix.com/oauth/request_token +// - AuthroizeTokenUrl: https://api-user.netflix.com/oauth/login +// - AccessTokenUrl: http://api.netflix.com/oauth/access_token +// Set HttpMethod if the service provider requires a different HTTP method +// to be used for OAuth token requests +type ServiceProvider struct { + RequestTokenUrl string + AuthorizeTokenUrl string + AccessTokenUrl string + HttpMethod string + BodyHash bool +} + +func (sp *ServiceProvider) httpMethod() string { + if sp.HttpMethod != "" { + return sp.HttpMethod + } + + return "GET" +} + +// lockedNonceGenerator wraps a non-reentrant random number generator with a +// lock +type lockedNonceGenerator struct { + nonceGenerator nonceGenerator + lock sync.Mutex +} + +func newLockedNonceGenerator(c clock) *lockedNonceGenerator { + return &lockedNonceGenerator{ + nonceGenerator: rand.New(rand.NewSource(c.Nanos())), + } +} + +func (n *lockedNonceGenerator) Int63() int64 { + n.lock.Lock() + r := n.nonceGenerator.Int63() + n.lock.Unlock() + return r +} + +// Consumers are stateless, you can call the various methods (GetRequestTokenAndUrl, +// AuthorizeToken, and Get) on various different instances of Consumers *as long as +// they were set up in the same way.* It is up to you, as the caller to persist the +// necessary state (RequestTokens and AccessTokens). +type Consumer struct { + // Some ServiceProviders require extra parameters to be passed for various reasons. + // For example Google APIs require you to set a scope= parameter to specify how much + // access is being granted. The proper values for scope= depend on the service: + // For more, see: http://code.google.com/apis/accounts/docs/OAuth.html#prepScope + AdditionalParams map[string]string + + // The rest of this class is configured via the NewConsumer function. + consumerKey string + serviceProvider ServiceProvider + + // Some APIs (e.g. Netflix) aren't quite standard OAuth, and require passing + // additional parameters when authorizing the request token. For most APIs + // this field can be ignored. For Netflix, do something like: + // consumer.AdditionalAuthorizationUrlParams = map[string]string{ + // "application_name": "YourAppName", + // "oauth_consumer_key": "YourConsumerKey", + // } + AdditionalAuthorizationUrlParams map[string]string + + debug bool + + // Defaults to http.Client{}, can be overridden (e.g. for testing) as necessary + HttpClient HttpClient + + // Some APIs (e.g. Intuit/Quickbooks) require sending additional headers along with + // requests. (like "Accept" to specify the response type as XML or JSON) Note that this + // will only *add* headers, not set existing ones. + AdditionalHeaders map[string][]string + + // Private seams for mocking dependencies when testing + clock clock + // Seeded generators are not reentrant + nonceGenerator nonceGenerator + signer signer +} + +func newConsumer(consumerKey string, serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { + clock := &defaultClock{} + if httpClient == nil { + httpClient = &http.Client{} + } + return &Consumer{ + consumerKey: consumerKey, + serviceProvider: serviceProvider, + clock: clock, + HttpClient: httpClient, + nonceGenerator: newLockedNonceGenerator(clock), + + AdditionalParams: make(map[string]string), + AdditionalAuthorizationUrlParams: make(map[string]string), + } +} + +// Creates a new Consumer instance, with a HMAC-SHA1 signer +// - consumerKey and consumerSecret: +// values you should obtain from the ServiceProvider when you register your +// application. +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +func NewConsumer(consumerKey string, consumerSecret string, + serviceProvider ServiceProvider) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, nil) + + consumer.signer = &HMACSigner{ + consumerSecret: consumerSecret, + hashFunc: crypto.SHA1, + } + + return consumer +} + +// Creates a new Consumer instance, with a HMAC-SHA1 signer +// - consumerKey and consumerSecret: +// values you should obtain from the ServiceProvider when you register your +// application. +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +// - httpClient: +// Provides a custom implementation of the httpClient used under the hood +// to make the request. This is especially useful if you want to use +// Google App Engine. +// +func NewCustomHttpClientConsumer(consumerKey string, consumerSecret string, + serviceProvider ServiceProvider, httpClient *http.Client) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, httpClient) + + consumer.signer = &HMACSigner{ + consumerSecret: consumerSecret, + hashFunc: crypto.SHA1, + } + + return consumer +} + +// Creates a new Consumer instance, with a HMAC signer +// - consumerKey and consumerSecret: +// values you should obtain from the ServiceProvider when you register your +// application. +// +// - hashFunc: +// the crypto.Hash to use for signatures +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +// - httpClient: +// Provides a custom implementation of the httpClient used under the hood +// to make the request. This is especially useful if you want to use +// Google App Engine. Can be nil for default. +// +func NewCustomConsumer(consumerKey string, consumerSecret string, + hashFunc crypto.Hash, serviceProvider ServiceProvider, + httpClient *http.Client) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, httpClient) + + consumer.signer = &HMACSigner{ + consumerSecret: consumerSecret, + hashFunc: hashFunc, + } + + return consumer +} + +// Creates a new Consumer instance, with a RSA-SHA1 signer +// - consumerKey: +// value you should obtain from the ServiceProvider when you register your +// application. +// +// - privateKey: +// the private key to use for signatures +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +func NewRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, + serviceProvider ServiceProvider) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, nil) + + consumer.signer = &RSASigner{ + privateKey: privateKey, + hashFunc: crypto.SHA1, + rand: cryptoRand.Reader, + } + + return consumer +} + +// Creates a new Consumer instance, with a RSA signer +// - consumerKey: +// value you should obtain from the ServiceProvider when you register your +// application. +// +// - privateKey: +// the private key to use for signatures +// +// - hashFunc: +// the crypto.Hash to use for signatures +// +// - serviceProvider: +// see the documentation for ServiceProvider for how to create this. +// +// - httpClient: +// Provides a custom implementation of the httpClient used under the hood +// to make the request. This is especially useful if you want to use +// Google App Engine. Can be nil for default. +// +func NewCustomRSAConsumer(consumerKey string, privateKey *rsa.PrivateKey, + hashFunc crypto.Hash, serviceProvider ServiceProvider, + httpClient *http.Client) *Consumer { + consumer := newConsumer(consumerKey, serviceProvider, httpClient) + + consumer.signer = &RSASigner{ + privateKey: privateKey, + hashFunc: hashFunc, + rand: cryptoRand.Reader, + } + + return consumer +} + +// Kicks off the OAuth authorization process. +// - callbackUrl: +// Authorizing a token *requires* redirecting to the service provider. This is the +// URL which the service provider will redirect the user back to after that +// authorization is completed. The service provider will pass back a verification +// code which is necessary to complete the rest of the process (in AuthorizeToken). +// Notes on callbackUrl: +// - Some (all?) service providers allow for setting "oob" (for out-of-band) as a +// callback url. If this is set the service provider will present the +// verification code directly to the user, and you must provide a place for +// them to copy-and-paste it into. +// - Otherwise, the user will be redirected to callbackUrl in the browser, and +// will append a "oauth_verifier=" parameter. +// +// This function returns: +// - rtoken: +// A temporary RequestToken, used during the authorization process. You must save +// this since it will be necessary later in the process when calling +// AuthorizeToken(). +// +// - url: +// A URL that you should redirect the user to in order that they may authorize you +// to the service provider. +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) GetRequestTokenAndUrl(callbackUrl string) (rtoken *RequestToken, loginUrl string, err error) { + params := c.baseParams(c.consumerKey, c.AdditionalParams) + if callbackUrl != "" { + params.Add(CALLBACK_PARAM, callbackUrl) + } + log.Info(fmt.Sprintf("method: %s url: %s authparams: %s",c.serviceProvider.httpMethod(),c.serviceProvider.RequestTokenUrl, params)) + req := &request{ + method: c.serviceProvider.httpMethod(), + url: c.serviceProvider.RequestTokenUrl, + oauthParams: params, + } + if _, err := c.signRequest(req, ""); err != nil { // We don't have a token secret for the key yet + return nil, "", err + } + + resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.RequestTokenUrl, params) + if err != nil { + return nil, "", errors.New("getBody: " + err.Error()) + } + + requestToken, err := parseRequestToken(*resp) + if err != nil { + return nil, "", errors.New("parseRequestToken: " + err.Error()) + } + + loginParams := make(url.Values) + for k, v := range c.AdditionalAuthorizationUrlParams { + loginParams.Set(k, v) + } + loginParams.Set(TOKEN_PARAM, requestToken.Token) + + loginUrl = c.serviceProvider.AuthorizeTokenUrl + "?" + loginParams.Encode() + + return requestToken, loginUrl, nil +} + +// After the user has authorized you to the service provider, use this method to turn +// your temporary RequestToken into a permanent AccessToken. You must pass in two values: +// - rtoken: +// The RequestToken returned from GetRequestTokenAndUrl() +// +// - verificationCode: +// The string which passed back from the server, either as the oauth_verifier +// query param appended to callbackUrl *OR* a string manually entered by the user +// if callbackUrl is "oob" +// +// It will return: +// - atoken: +// A permanent AccessToken which can be used to access the user's data (until it is +// revoked by the user or the service provider). +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) AuthorizeToken(rtoken *RequestToken, verificationCode string) (atoken *AccessToken, err error) { + params := map[string]string{ + VERIFIER_PARAM: verificationCode, + TOKEN_PARAM: rtoken.Token, + } + + return c.makeAccessTokenRequest(params, rtoken.Secret) +} + +// Use the service provider to refresh the AccessToken for a given session. +// Note that this is only supported for service providers that manage an +// authorization session (e.g. Yahoo). +// +// Most providers do not return the SESSION_HANDLE_PARAM needed to refresh +// the token. +// +// See http://oauth.googlecode.com/svn/spec/ext/session/1.0/drafts/1/spec.html +// for more information. +// - accessToken: +// The AccessToken returned from AuthorizeToken() +// +// It will return: +// - atoken: +// An AccessToken which can be used to access the user's data (until it is +// revoked by the user or the service provider). +// +// - err: +// Set if accessToken does not contain the SESSION_HANDLE_PARAM needed to +// refresh the token, or if an error occurred when making the request. +func (c *Consumer) RefreshToken(accessToken *AccessToken) (atoken *AccessToken, err error) { + params := make(map[string]string) + sessionHandle, ok := accessToken.AdditionalData[SESSION_HANDLE_PARAM] + if !ok { + return nil, errors.New("Missing " + SESSION_HANDLE_PARAM + " in access token.") + } + params[SESSION_HANDLE_PARAM] = sessionHandle + params[TOKEN_PARAM] = accessToken.Token + + return c.makeAccessTokenRequest(params, accessToken.Secret) +} + +// Use the service provider to obtain an AccessToken for a given session +// - params: +// The access token request paramters. +// +// - secret: +// Secret key to use when signing the access token request. +// +// It will return: +// - atoken +// An AccessToken which can be used to access the user's data (until it is +// revoked by the user or the service provider). +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) makeAccessTokenRequest(params map[string]string, secret string) (atoken *AccessToken, err error) { + orderedParams := c.baseParams(c.consumerKey, c.AdditionalParams) + for key, value := range params { + orderedParams.Add(key, value) + } + + req := &request{ + method: c.serviceProvider.httpMethod(), + url: c.serviceProvider.AccessTokenUrl, + oauthParams: orderedParams, + } + if _, err := c.signRequest(req, secret); err != nil { + return nil, err + } + + resp, err := c.getBody(c.serviceProvider.httpMethod(), c.serviceProvider.AccessTokenUrl, orderedParams) + if err != nil { + return nil, err + } + + return parseAccessToken(*resp) +} + +type RoundTripper struct { + consumer *Consumer + token *AccessToken +} + +func (c *Consumer) MakeRoundTripper(token *AccessToken) (*RoundTripper, error) { + return &RoundTripper{consumer: c, token: token}, nil +} + +func (c *Consumer) MakeHttpClient(token *AccessToken) (*http.Client, error) { + return &http.Client{ + Transport: &RoundTripper{consumer: c, token: token}, + }, nil +} + +// ** DEPRECATED ** +// Please call Get on the http client returned by MakeHttpClient instead! +// +// Executes an HTTP Get, authorized via the AccessToken. +// - url: +// The base url, without any query params, which is being accessed +// +// - userParams: +// Any key=value params to be included in the query string +// +// - token: +// The AccessToken returned by AuthorizeToken() +// +// This method returns: +// - resp: +// The HTTP Response resulting from making this request. +// +// - err: +// Set only if there was an error, nil otherwise. +func (c *Consumer) Get(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("GET", url, LOC_URL, "", userParams, token) +} + +func encodeUserParams(userParams map[string]string) string { + data := url.Values{} + for k, v := range userParams { + data.Add(k, v) + } + return data.Encode() +} + +// ** DEPRECATED ** +// Please call "Post" on the http client returned by MakeHttpClient instead +func (c *Consumer) PostForm(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.PostWithBody(url, "", userParams, token) +} + +// ** DEPRECATED ** +// Please call "Post" on the http client returned by MakeHttpClient instead +func (c *Consumer) Post(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.PostWithBody(url, "", userParams, token) +} + +// ** DEPRECATED ** +// Please call "Post" on the http client returned by MakeHttpClient instead +func (c *Consumer) PostWithBody(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("POST", url, LOC_BODY, body, userParams, token) +} + +// ** DEPRECATED ** +// Please call "Do" on the http client returned by MakeHttpClient instead +// (and set the "Content-Type" header explicitly in the http.Request) +func (c *Consumer) PostJson(url string, body string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("POST", url, LOC_JSON, body, nil, token) +} + +// ** DEPRECATED ** +// Please call "Do" on the http client returned by MakeHttpClient instead +// (and set the "Content-Type" header explicitly in the http.Request) +func (c *Consumer) PostXML(url string, body string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("POST", url, LOC_XML, body, nil, token) +} + +// ** DEPRECATED ** +// Please call "Do" on the http client returned by MakeHttpClient instead +// (and setup the multipart data explicitly in the http.Request) +func (c *Consumer) PostMultipart(url, multipartName string, multipartData io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequestReader("POST", url, LOC_MULTIPART, 0, multipartName, multipartData, userParams, token) +} + +// ** DEPRECATED ** +// Please call "Delete" on the http client returned by MakeHttpClient instead +func (c *Consumer) Delete(url string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("DELETE", url, LOC_URL, "", userParams, token) +} + +// ** DEPRECATED ** +// Please call "Put" on the http client returned by MakeHttpClient instead +func (c *Consumer) Put(url string, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequest("PUT", url, LOC_URL, body, userParams, token) +} + +func (c *Consumer) Debug(enabled bool) { + c.debug = enabled + c.signer.Debug(enabled) +} + +type pair struct { + key string + value string +} + +type pairs []pair + +func (p pairs) Len() int { return len(p) } +func (p pairs) Less(i, j int) bool { return p[i].key < p[j].key } +func (p pairs) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +// This function has basically turned into a backwards compatibility layer +// between the old API (where clients explicitly called consumer.Get() +// consumer.Post() etc), and the new API (which takes actual http.Requests) +// +// So, here we construct the appropriate HTTP request for the inputs. +func (c *Consumer) makeAuthorizedRequestReader(method string, urlString string, dataLocation DataLocation, contentLength int, multipartName string, body io.ReadCloser, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + urlObject, err := url.Parse(urlString) + if err != nil { + return nil, err + } + + request := &http.Request{ + Method: method, + URL: urlObject, + Header: http.Header{}, + Body: body, + ContentLength: int64(contentLength), + } + + vals := url.Values{} + for k, v := range userParams { + vals.Add(k, v) + } + + if dataLocation != LOC_BODY { + request.URL.RawQuery = vals.Encode() + request.URL.RawQuery = strings.Replace( + request.URL.RawQuery, ";", "%3B", -1) + + } else { + // TODO(mrjones): validate that we're not overrideing an exising body? + request.Body = ioutil.NopCloser(strings.NewReader(vals.Encode())) + request.ContentLength = int64(len(vals.Encode())) + } + + for k, vs := range c.AdditionalHeaders { + for _, v := range vs { + request.Header.Set(k, v) + } + } + + if dataLocation == LOC_BODY { + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + + if dataLocation == LOC_JSON { + request.Header.Set("Content-Type", "application/json") + } + + if dataLocation == LOC_XML { + request.Header.Set("Content-Type", "application/xml") + } + + if dataLocation == LOC_MULTIPART { + pipeReader, pipeWriter := io.Pipe() + writer := multipart.NewWriter(pipeWriter) + if request.URL.Host == "www.mrjon.es" && + request.URL.Path == "/unittest" { + writer.SetBoundary("UNITTESTBOUNDARY") + } + go func(body io.Reader) { + part, err := writer.CreateFormFile(multipartName, "/no/matter") + if err != nil { + writer.Close() + pipeWriter.CloseWithError(err) + return + } + _, err = io.Copy(part, body) + if err != nil { + writer.Close() + pipeWriter.CloseWithError(err) + return + } + writer.Close() + pipeWriter.Close() + }(body) + request.Body = pipeReader + request.Header.Set("Content-Type", writer.FormDataContentType()) + } + + rt := RoundTripper{consumer: c, token: token} + + resp, err = rt.RoundTrip(request) + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + defer resp.Body.Close() + bytes, _ := ioutil.ReadAll(resp.Body) + + return resp, HTTPExecuteError{ + RequestHeaders: "", + ResponseBodyBytes: bytes, + Status: resp.Status, + StatusCode: resp.StatusCode, + } + } + + return resp, nil +} + +// cloneReq clones the src http.Request, making deep copies of the Header and +// the URL but shallow copies of everything else +func cloneReq(src *http.Request) *http.Request { + dst := &http.Request{} + *dst = *src + + dst.Header = make(http.Header, len(src.Header)) + for k, s := range src.Header { + dst.Header[k] = append([]string(nil), s...) + } + + if src.URL != nil { + dst.URL = cloneURL(src.URL) + } + + return dst +} + +// cloneURL shallow clones the src *url.URL +func cloneURL(src *url.URL) *url.URL { + dst := &url.URL{} + *dst = *src + + return dst +} + +func canonicalizeUrl(u *url.URL) string { + var buf bytes.Buffer + buf.WriteString(u.Scheme) + buf.WriteString("://") + buf.WriteString(u.Host) + buf.WriteString(u.Path) + + return buf.String() +} + +func parseBody(request *http.Request) (map[string]string, error) { + userParams := map[string]string{} + + // TODO(mrjones): factor parameter extraction into a separate method + if request.Header.Get("Content-Type") != + "application/x-www-form-urlencoded" { + // Most of the time we get parameters from the query string: + for k, vs := range request.URL.Query() { + if len(vs) != 1 { + return nil, fmt.Errorf("Must have exactly one value per param") + } + + userParams[k] = vs[0] + } + } else { + // x-www-form-urlencoded parameters come from the body instead: + defer request.Body.Close() + originalBody, err := ioutil.ReadAll(request.Body) + if err != nil { + return nil, err + } + + // If there was a body, we have to re-install it + // (because we've ruined it by reading it). + request.Body = ioutil.NopCloser(bytes.NewReader(originalBody)) + + params, err := url.ParseQuery(string(originalBody)) + if err != nil { + return nil, err + } + + for k, vs := range params { + if len(vs) != 1 { + return nil, fmt.Errorf("Must have exactly one value per param") + } + + userParams[k] = vs[0] + } + } + + return userParams, nil +} + +func paramsToSortedPairs(params map[string]string) pairs { + // Sort parameters alphabetically + paramPairs := make(pairs, len(params)) + i := 0 + for key, value := range params { + paramPairs[i] = pair{key: key, value: value} + i++ + } + sort.Sort(paramPairs) + + return paramPairs +} + +func calculateBodyHash(request *http.Request, s signer) (string, error) { + if request.Header.Get("Content-Type") == + "application/x-www-form-urlencoded" { + return "", nil + } + + var originalBody []byte + + if request.Body != nil { + var err error + + defer request.Body.Close() + originalBody, err = ioutil.ReadAll(request.Body) + if err != nil { + return "", err + } + + // If there was a body, we have to re-install it + // (because we've ruined it by reading it). + request.Body = ioutil.NopCloser(bytes.NewReader(originalBody)) + } + + h := s.HashFunc().New() + h.Write(originalBody) + rawSignature := h.Sum(nil) + + return base64.StdEncoding.EncodeToString(rawSignature), nil +} + +func (rt *RoundTripper) RoundTrip(userRequest *http.Request) (*http.Response, error) { + serverRequest := cloneReq(userRequest) + + allParams := rt.consumer.baseParams( + rt.consumer.consumerKey, rt.consumer.AdditionalParams) + + // Do not add the "oauth_token" parameter, if the access token has not been + // specified. By omitting this parameter when it is not specified, allows + // two-legged OAuth calls. + if len(rt.token.Token) > 0 { + allParams.Add(TOKEN_PARAM, rt.token.Token) + } + + if rt.consumer.serviceProvider.BodyHash { + bodyHash, err := calculateBodyHash(serverRequest, rt.consumer.signer) + if err != nil { + return nil, err + } + + if bodyHash != "" { + allParams.Add(BODY_HASH_PARAM, bodyHash) + } + } + + authParams := allParams.Clone() + + // TODO(mrjones): put these directly into the paramPairs below? + userParams, err := parseBody(serverRequest) + if err != nil { + return nil, err + } + paramPairs := paramsToSortedPairs(userParams) + + for i := range paramPairs { + allParams.Add(paramPairs[i].key, paramPairs[i].value) + } + + signingURL := cloneURL(serverRequest.URL) + if host := serverRequest.Host; host != "" { + signingURL.Host = host + } + baseString := rt.consumer.requestString(serverRequest.Method, canonicalizeUrl(signingURL), allParams) + + signature, err := rt.consumer.signer.Sign(baseString, rt.token.Secret) + if err != nil { + return nil, err + } + + authParams.Add(SIGNATURE_PARAM, signature) + + // Set auth header. + oauthHdr := OAUTH_HEADER + for pos, key := range authParams.Keys() { + for innerPos, value := range authParams.Get(key) { + if pos+innerPos > 0 { + oauthHdr += "," + } + oauthHdr += key + "=\"" + value + "\"" + } + } + serverRequest.Header.Add(HTTP_AUTH_HEADER, oauthHdr) + + if rt.consumer.debug { + fmt.Printf("Request: %v\n", serverRequest) + } + + resp, err := rt.consumer.HttpClient.Do(serverRequest) + + if err != nil { + return resp, err + } + + return resp, nil +} + +func (c *Consumer) makeAuthorizedRequest(method string, url string, dataLocation DataLocation, body string, userParams map[string]string, token *AccessToken) (resp *http.Response, err error) { + return c.makeAuthorizedRequestReader(method, url, dataLocation, len(body), "", ioutil.NopCloser(strings.NewReader(body)), userParams, token) +} + +type request struct { + method string + url string + oauthParams *OrderedParams + userParams map[string]string +} + +type HttpClient interface { + Do(req *http.Request) (resp *http.Response, err error) +} + +type clock interface { + Seconds() int64 + Nanos() int64 +} + +type nonceGenerator interface { + Int63() int64 +} + +type key interface { + String() string +} + +type signer interface { + Sign(message string, tokenSecret string) (string, error) + Verify(message string, signature string) error + SignatureMethod() string + HashFunc() crypto.Hash + Debug(enabled bool) +} + +type defaultClock struct{} + +func (*defaultClock) Seconds() int64 { + return time.Now().Unix() +} + +func (*defaultClock) Nanos() int64 { + return time.Now().UnixNano() +} + +func (c *Consumer) signRequest(req *request, tokenSecret string) (*request, error) { + baseString := c.requestString(req.method, req.url, req.oauthParams) + + signature, err := c.signer.Sign(baseString, tokenSecret) + if err != nil { + return nil, err + } + + req.oauthParams.Add(SIGNATURE_PARAM, signature) + return req, nil +} + +// Obtains an AccessToken from the response of a service provider. +// - data: +// The response body. +// +// This method returns: +// - atoken: +// The AccessToken generated from the response body. +// +// - err: +// Set if an AccessToken could not be parsed from the given input. +func parseAccessToken(data string) (atoken *AccessToken, err error) { + parts, err := url.ParseQuery(data) + if err != nil { + return nil, err + } + + tokenParam := parts[TOKEN_PARAM] + parts.Del(TOKEN_PARAM) + if len(tokenParam) < 1 { + return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " + + "Full response body: '" + data + "'") + } + tokenSecretParam := parts[TOKEN_SECRET_PARAM] + parts.Del(TOKEN_SECRET_PARAM) + if len(tokenSecretParam) < 1 { + return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." + + "Full response body: '" + data + "'") + } + + additionalData := parseAdditionalData(parts) + + return &AccessToken{tokenParam[0], tokenSecretParam[0], additionalData}, nil +} + +func parseRequestToken(data string) (*RequestToken, error) { + parts, err := url.ParseQuery(data) + if err != nil { + return nil, err + } + + tokenParam := parts[TOKEN_PARAM] + if len(tokenParam) < 1 { + return nil, errors.New("Missing " + TOKEN_PARAM + " in response. " + + "Full response body: '" + data + "'") + } + tokenSecretParam := parts[TOKEN_SECRET_PARAM] + if len(tokenSecretParam) < 1 { + return nil, errors.New("Missing " + TOKEN_SECRET_PARAM + " in response." + + "Full response body: '" + data + "'") + } + return &RequestToken{tokenParam[0], tokenSecretParam[0]}, nil +} + +func (c *Consumer) baseParams(consumerKey string, additionalParams map[string]string) *OrderedParams { + params := NewOrderedParams() + params.Add(VERSION_PARAM, OAUTH_VERSION) + params.Add(SIGNATURE_METHOD_PARAM, c.signer.SignatureMethod()) + params.Add(TIMESTAMP_PARAM, strconv.FormatInt(c.clock.Seconds(), 10)) + params.Add(NONCE_PARAM, strconv.FormatInt(c.nonceGenerator.Int63(), 10)) + params.Add(CONSUMER_KEY_PARAM, consumerKey) + for key, value := range additionalParams { + params.Add(key, value) + } + return params +} + +func parseAdditionalData(parts url.Values) map[string]string { + params := make(map[string]string) + for key, value := range parts { + if len(value) > 0 { + params[key] = value[0] + } + } + return params +} + +type HMACSigner struct { + consumerSecret string + hashFunc crypto.Hash + debug bool +} + +func (s *HMACSigner) Debug(enabled bool) { + s.debug = enabled +} + +func (s *HMACSigner) Sign(message string, tokenSecret string) (string, error) { + key := escape(s.consumerSecret) + "&" + escape(tokenSecret) + if s.debug { + fmt.Println("Signing:", message) + fmt.Println("Key:", key) + } + + h := hmac.New(s.HashFunc().New, []byte(key)) + h.Write([]byte(message)) + rawSignature := h.Sum(nil) + + base64signature := base64.StdEncoding.EncodeToString(rawSignature) + if s.debug { + fmt.Println("Base64 signature:", base64signature) + } + return base64signature, nil +} + +func (s *HMACSigner) Verify(message string, signature string) error { + if s.debug { + fmt.Println("Verifying Base64 signature:", signature) + } + validSignature, err := s.Sign(message, "") + if err != nil { + return err + } + + if validSignature != signature { + return fmt.Errorf("signature did not match") + } + + return nil +} + +func (s *HMACSigner) SignatureMethod() string { + return SIGNATURE_METHOD_HMAC + HASH_METHOD_MAP[s.HashFunc()] +} + +func (s *HMACSigner) HashFunc() crypto.Hash { + return s.hashFunc +} + +type RSASigner struct { + debug bool + rand io.Reader + privateKey *rsa.PrivateKey + hashFunc crypto.Hash +} + +func (s *RSASigner) Debug(enabled bool) { + s.debug = enabled +} + +func (s *RSASigner) Sign(message string, tokenSecret string) (string, error) { + if s.debug { + fmt.Println("Signing:", message) + } + + h := s.HashFunc().New() + h.Write([]byte(message)) + digest := h.Sum(nil) + + signature, err := rsa.SignPKCS1v15(s.rand, s.privateKey, s.HashFunc(), digest) + if err != nil { + return "", nil + } + + base64signature := base64.StdEncoding.EncodeToString(signature) + if s.debug { + fmt.Println("Base64 signature:", base64signature) + } + + return base64signature, nil +} + +func (s *RSASigner) Verify(message string, base64signature string) error { + if s.debug { + fmt.Println("Verifying:", message) + fmt.Println("Verifying Base64 signature:", base64signature) + } + + h := s.HashFunc().New() + h.Write([]byte(message)) + digest := h.Sum(nil) + + signature, err := base64.StdEncoding.DecodeString(base64signature) + if err != nil { + return err + } + + return rsa.VerifyPKCS1v15(&s.privateKey.PublicKey, s.HashFunc(), digest, signature) +} + +func (s *RSASigner) SignatureMethod() string { + return SIGNATURE_METHOD_RSA + HASH_METHOD_MAP[s.HashFunc()] +} + +func (s *RSASigner) HashFunc() crypto.Hash { + return s.hashFunc +} + +func escape(s string) string { + t := make([]byte, 0, 3*len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if isEscapable(c) { + t = append(t, '%') + t = append(t, "0123456789ABCDEF"[c>>4]) + t = append(t, "0123456789ABCDEF"[c&15]) + } else { + t = append(t, s[i]) + } + } + return string(t) +} + +func isEscapable(b byte) bool { + return !('A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' || b == '-' || b == '.' || b == '_' || b == '~') + +} + +func (c *Consumer) requestString(method string, url string, params *OrderedParams) string { + result := method + "&" + escape(url) + for pos, key := range params.Keys() { + for innerPos, value := range params.Get(key) { + if pos+innerPos == 0 { + result += "&" + } else { + result += escape("&") + } + result += escape(fmt.Sprintf("%s=%s", key, value)) + } + } + return result +} + +func (c *Consumer) getBody(method, url string, oauthParams *OrderedParams) (*string, error) { + resp, err := c.httpExecute(method, url, "", 0, nil, oauthParams) + if err != nil { + return nil, errors.New("httpExecute: " + err.Error()) + } + bodyBytes, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, errors.New("ReadAll: " + err.Error()) + } + bodyStr := string(bodyBytes) + if c.debug { + fmt.Printf("STATUS: %d %s\n", resp.StatusCode, resp.Status) + fmt.Println("BODY RESPONSE: " + bodyStr) + } + return &bodyStr, nil +} + +// HTTPExecuteError signals that a call to httpExecute failed. +type HTTPExecuteError struct { + // RequestHeaders provides a stringified listing of request headers. + RequestHeaders string + // ResponseBodyBytes is the response read into a byte slice. + ResponseBodyBytes []byte + // Status is the status code string response. + Status string + // StatusCode is the parsed status code. + StatusCode int +} + +// Error provides a printable string description of an HTTPExecuteError. +func (e HTTPExecuteError) Error() string { + return "HTTP response is not 200/OK as expected. Actual response: \n" + + "\tResponse Status: '" + e.Status + "'\n" + + "\tResponse Code: " + strconv.Itoa(e.StatusCode) + "\n" + + "\tResponse Body: " + string(e.ResponseBodyBytes) + "\n" + + "\tRequest Headers: " + e.RequestHeaders +} + +func (c *Consumer) httpExecute( + method string, urlStr string, contentType string, contentLength int, body io.Reader, oauthParams *OrderedParams) (*http.Response, error) { + // Create base request. + req, err := http.NewRequest(method, urlStr, body) + if err != nil { + return nil, errors.New("NewRequest failed: " + err.Error()) + } + + // Set auth header. + req.Header = http.Header{} + oauthHdr := "OAuth " + for pos, key := range oauthParams.Keys() { + for innerPos, value := range oauthParams.Get(key) { + if pos+innerPos > 0 { + oauthHdr += "," + } + oauthHdr += key + "=\"" + value + "\"" + } + } + req.Header.Add("Authorization", oauthHdr) + + // Add additional custom headers + for key, vals := range c.AdditionalHeaders { + for _, val := range vals { + req.Header.Add(key, val) + } + } + + // Set contentType if passed. + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + + // Set contentLength if passed. + if contentLength > 0 { + req.Header.Set("Content-Length", strconv.Itoa(contentLength)) + } + + if c.debug { + fmt.Printf("Request: %v\n", req) + } + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, errors.New("Do: " + err.Error()) + } + + debugHeader := "" + for k, vals := range req.Header { + for _, val := range vals { + debugHeader += "[key: " + k + ", val: " + val + "]" + } + } + + // StatusMultipleChoices is 300, any 2xx response should be treated as success + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + defer resp.Body.Close() + bytes, _ := ioutil.ReadAll(resp.Body) + + return resp, HTTPExecuteError{ + RequestHeaders: debugHeader, + ResponseBodyBytes: bytes, + Status: resp.Status, + StatusCode: resp.StatusCode, + } + } + return resp, err +} + +// +// String Sorting helpers +// + +type ByValue []string + +func (a ByValue) Len() int { + return len(a) +} + +func (a ByValue) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +func (a ByValue) Less(i, j int) bool { + return a[i] < a[j] +} + +// +// ORDERED PARAMS +// + +type OrderedParams struct { + allParams map[string][]string + keyOrdering []string +} + +func NewOrderedParams() *OrderedParams { + return &OrderedParams{ + allParams: make(map[string][]string), + keyOrdering: make([]string, 0), + } +} + +func (o *OrderedParams) Get(key string) []string { + sort.Sort(ByValue(o.allParams[key])) + return o.allParams[key] +} + +func (o *OrderedParams) Keys() []string { + sort.Sort(o) + return o.keyOrdering +} + +func (o *OrderedParams) Add(key, value string) { + o.AddUnescaped(key, escape(value)) +} + +func (o *OrderedParams) AddUnescaped(key, value string) { + if _, exists := o.allParams[key]; !exists { + o.keyOrdering = append(o.keyOrdering, key) + o.allParams[key] = make([]string, 1) + o.allParams[key][0] = value + } else { + o.allParams[key] = append(o.allParams[key], value) + } +} + +func (o *OrderedParams) Len() int { + return len(o.keyOrdering) +} + +func (o *OrderedParams) Less(i int, j int) bool { + return o.keyOrdering[i] < o.keyOrdering[j] +} + +func (o *OrderedParams) Swap(i int, j int) { + o.keyOrdering[i], o.keyOrdering[j] = o.keyOrdering[j], o.keyOrdering[i] +} + +func (o *OrderedParams) Clone() *OrderedParams { + clone := NewOrderedParams() + for _, key := range o.Keys() { + for _, value := range o.Get(key) { + clone.AddUnescaped(key, value) + } + } + return clone +} diff --git a/vendor/github.com/mrjones/oauth/pre-commit.sh b/vendor/github.com/mrjones/oauth/pre-commit.sh new file mode 100755 index 000000000..91b9e8823 --- /dev/null +++ b/vendor/github.com/mrjones/oauth/pre-commit.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# ln -s $PWD/pre-commit.sh .git/hooks/pre-commit +go test *.go +RESULT=$? +if [[ $RESULT != 0 ]]; then + echo "REJECTING COMMIT (test failed with status: $RESULT)" + exit 1; +fi + +go fmt *.go +for e in $(ls examples); do + go build examples/$e/*.go + RESULT=$? + if [[ $RESULT != 0 ]]; then + echo "REJECTING COMMIT (Examples failed to compile)" + exit $RESULT; + fi + go fmt examples/$e/*.go +done + +exit 0 diff --git a/vendor/github.com/mrjones/oauth/provider.go b/vendor/github.com/mrjones/oauth/provider.go new file mode 100644 index 000000000..b15cb9874 --- /dev/null +++ b/vendor/github.com/mrjones/oauth/provider.go @@ -0,0 +1,147 @@ +package oauth + +import ( + "bytes" + "fmt" + "math" + "net/http" + "net/url" + "strconv" + "strings" +) + +// +// OAuth1 2-legged provider +// Contributed by https://github.com/jacobpgallagher +// + +// Provide an buffer reader which implements the Close() interface +type oauthBufferReader struct { + *bytes.Buffer +} + +// So that it implements the io.ReadCloser interface +func (m oauthBufferReader) Close() error { return nil } + +type ConsumerGetter func(key string, header map[string]string) (*Consumer, error) + +// Provider provides methods for a 2-legged Oauth1 provider +type Provider struct { + ConsumerGetter ConsumerGetter + + // For mocking + clock clock +} + +// NewProvider takes a function to get the consumer secret from a datastore. +// Returns a Provider +func NewProvider(secretGetter ConsumerGetter) *Provider { + provider := &Provider{ + secretGetter, + &defaultClock{}, + } + return provider +} + +// Combine a URL and Request to make the URL absolute +func makeURLAbs(url *url.URL, request *http.Request) { + if !url.IsAbs() { + url.Host = request.Host + if request.TLS != nil || request.Header.Get("X-Forwarded-Proto") == "https" { + url.Scheme = "https" + } else { + url.Scheme = "http" + } + } +} + +// IsAuthorized takes an *http.Request and returns a pointer to a string containing the consumer key, +// or nil if not authorized +func (provider *Provider) IsAuthorized(request *http.Request) (*string, error) { + var err error + var userParams map[string]string + + // start with the body/query params + userParams, err = parseBody(request) + if err != nil { + return nil, err + } + + // if the oauth params are in the Authorization header, grab them, and + // let them override what's in userParams + authHeader := request.Header.Get(HTTP_AUTH_HEADER) + if len(authHeader) > 6 && strings.EqualFold(OAUTH_HEADER, authHeader[0:6]) { + authHeader = authHeader[6:] + params := strings.Split(authHeader, ",") + for _, param := range params { + vals := strings.SplitN(param, "=", 2) + k := strings.Trim(vals[0], " ") + v := strings.Trim(strings.Trim(vals[1], "\""), " ") + if strings.HasPrefix(k, "oauth") { + userParams[k], err = url.QueryUnescape(v) + if err != nil { + return nil, err + } + } + } + } + + // pop the request's signature, it's not included in our signature + // calculation + oauthSignature, ok := userParams[SIGNATURE_PARAM] + if !ok { + return nil, fmt.Errorf("no oauth signature") + } + delete(userParams, SIGNATURE_PARAM) + + // Check the timestamp + oauthTimeNumber, err := strconv.Atoi(userParams[TIMESTAMP_PARAM]) + if err != nil { + return nil, err + } + if math.Abs(float64(int64(oauthTimeNumber)-provider.clock.Seconds())) > 5*60 { + return nil, fmt.Errorf("too much clock skew") + } + + // get the oauth consumer key + consumerKey, ok := userParams[CONSUMER_KEY_PARAM] + if !ok { + return nil, fmt.Errorf("no consumer key") + } + + // use it to create a consumer object + consumer, err := provider.ConsumerGetter(consumerKey, userParams) + if err != nil { + return nil, err + } + + // if our consumer supports bodyhash, check it + if consumer.serviceProvider.BodyHash { + bodyHash, err := calculateBodyHash(request, consumer.signer) + if err != nil { + return nil, err + } + + sentHash, ok := userParams[BODY_HASH_PARAM] + + if bodyHash == "" && ok { + return nil, fmt.Errorf("body_hash must not be set") + } else if sentHash != bodyHash { + return nil, fmt.Errorf("body_hash mismatch") + } + } + + allParams := NewOrderedParams() + for key, value := range userParams { + allParams.Add(key, value) + } + + makeURLAbs(request.URL, request) + baseString := consumer.requestString(request.Method, canonicalizeUrl(request.URL), allParams) + err = consumer.signer.Verify(baseString, oauthSignature) + if err != nil { + return nil, err + } + + return &consumerKey, nil +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 7ca09d605..c41d3ed06 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -169,6 +169,12 @@ "revisionTime": "2016-03-07T18:57:06+09:00", "tree": true }, + { + "checksumSHA1": "spRLFk8daizvEzOmRsxlxeECdHI=", + "path": "github.com/mrjones/oauth", + "revision": "31f1e8e5addda51bc50ebfc8bb930d4642372654", + "revisionTime": "2016-04-05T23:58:02Z" + }, { "origin": "github.com/stretchr/testify/vendor/github.com/pmezard/go-difflib/difflib", "path": "github.com/pmezard/go-difflib/difflib",