mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-03-23 21:50:48 +02:00
Merge pull request #660 from oauth2-proxy/request-builder
Use builder pattern to simplify requests to external endpoints
This commit is contained in:
commit
d29766609b
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
## Changes since v6.0.0
|
## Changes since v6.0.0
|
||||||
|
|
||||||
|
- [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed)
|
||||||
- [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed)
|
- [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed)
|
||||||
- [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed)
|
- [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed)
|
||||||
- [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves)
|
- [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves)
|
||||||
|
118
pkg/requests/builder.go
Normal file
118
pkg/requests/builder.go
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Builder allows users to construct a request and then execute the
|
||||||
|
// request via Do().
|
||||||
|
// Do returns a Result which allows the user to get the body,
|
||||||
|
// unmarshal the body into an interface, or into a simplejson.Json.
|
||||||
|
type Builder interface {
|
||||||
|
WithContext(context.Context) Builder
|
||||||
|
WithBody(io.Reader) Builder
|
||||||
|
WithMethod(string) Builder
|
||||||
|
WithHeaders(http.Header) Builder
|
||||||
|
SetHeader(key, value string) Builder
|
||||||
|
Do() Result
|
||||||
|
}
|
||||||
|
|
||||||
|
type builder struct {
|
||||||
|
context context.Context
|
||||||
|
method string
|
||||||
|
endpoint string
|
||||||
|
body io.Reader
|
||||||
|
header http.Header
|
||||||
|
result *result
|
||||||
|
}
|
||||||
|
|
||||||
|
// New provides a new Builder for the given endpoint.
|
||||||
|
func New(endpoint string) Builder {
|
||||||
|
return &builder{
|
||||||
|
endpoint: endpoint,
|
||||||
|
method: "GET",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithContext adds a context to the request.
|
||||||
|
// If no context is provided, context.Background() is used instead.
|
||||||
|
func (r *builder) WithContext(ctx context.Context) Builder {
|
||||||
|
r.context = ctx
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithBody adds a body to the request.
|
||||||
|
func (r *builder) WithBody(body io.Reader) Builder {
|
||||||
|
r.body = body
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMethod sets the request method. Defaults to "GET".
|
||||||
|
func (r *builder) WithMethod(method string) Builder {
|
||||||
|
r.method = method
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHeaders replaces the request header map with the given header map.
|
||||||
|
func (r *builder) WithHeaders(header http.Header) Builder {
|
||||||
|
r.header = header
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeader sets a single header to the given value.
|
||||||
|
// May be used to add multiple headers.
|
||||||
|
func (r *builder) SetHeader(key, value string) Builder {
|
||||||
|
if r.header == nil {
|
||||||
|
r.header = make(http.Header)
|
||||||
|
}
|
||||||
|
r.header.Set(key, value)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do performs the request and returns the response in its raw form.
|
||||||
|
// If the request has already been performed, returns the previous result.
|
||||||
|
// This will not allow you to repeat a request.
|
||||||
|
func (r *builder) Do() Result {
|
||||||
|
if r.result != nil {
|
||||||
|
// Request has already been done
|
||||||
|
return r.result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must provide a non-nil context to NewRequestWithContext
|
||||||
|
if r.context == nil {
|
||||||
|
r.context = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.do()
|
||||||
|
}
|
||||||
|
|
||||||
|
// do creates the request, executes it with the default client and extracts the
|
||||||
|
// the body into the response
|
||||||
|
func (r *builder) do() Result {
|
||||||
|
req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body)
|
||||||
|
if err != nil {
|
||||||
|
r.result = &result{err: fmt.Errorf("error creating request: %v", err)}
|
||||||
|
return r.result
|
||||||
|
}
|
||||||
|
req.Header = r.header
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
r.result = &result{err: fmt.Errorf("error performing request: %v", err)}
|
||||||
|
return r.result
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
r.result = &result{err: fmt.Errorf("error reading response body: %v", err)}
|
||||||
|
return r.result
|
||||||
|
}
|
||||||
|
|
||||||
|
r.result = &result{response: resp, body: body}
|
||||||
|
return r.result
|
||||||
|
}
|
376
pkg/requests/builder_test.go
Normal file
376
pkg/requests/builder_test.go
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitly/go-simplejson"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Builder suite", func() {
|
||||||
|
var b Builder
|
||||||
|
getBuilder := func() Builder { return b }
|
||||||
|
|
||||||
|
baseHeaders := http.Header{
|
||||||
|
"Accept-Encoding": []string{"gzip"},
|
||||||
|
"User-Agent": []string{"Go-http-client/1.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
// Most tests will request the server address
|
||||||
|
b = New(serverAddr + "/json/path")
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with a basic request", func() {
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: baseHeaders,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with a context", func() {
|
||||||
|
var ctx context.Context
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
b = b.WithContext(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: baseHeaders,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("if the context is cancelled", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
|
||||||
|
assertRequestError(getBuilder, "context canceled")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with a body", func() {
|
||||||
|
const body = "{\"some\": \"body\"}"
|
||||||
|
header := baseHeaders.Clone()
|
||||||
|
header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
buf := bytes.NewBuffer([]byte(body))
|
||||||
|
b = b.WithBody(buf)
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: header,
|
||||||
|
Body: []byte(body),
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with a method", func() {
|
||||||
|
Context("POST with a body", func() {
|
||||||
|
const body = "{\"some\": \"body\"}"
|
||||||
|
header := baseHeaders.Clone()
|
||||||
|
header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
buf := bytes.NewBuffer([]byte(body))
|
||||||
|
b = b.WithMethod("POST").WithBody(buf)
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "POST",
|
||||||
|
Header: header,
|
||||||
|
Body: []byte(body),
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("POST without a body", func() {
|
||||||
|
header := baseHeaders.Clone()
|
||||||
|
header.Set("Content-Length", "0")
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.WithMethod("POST")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "POST",
|
||||||
|
Header: header,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("OPTIONS", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.WithMethod("OPTIONS")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "OPTIONS",
|
||||||
|
Header: baseHeaders,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("INVALID-\\t-METHOD", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.WithMethod("INVALID-\t-METHOD")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with headers", func() {
|
||||||
|
Context("setting a header", func() {
|
||||||
|
header := baseHeaders.Clone()
|
||||||
|
header.Set("header", "value")
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.SetHeader("header", "value")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: header,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("then replacing the headers", func() {
|
||||||
|
replacementHeaders := http.Header{
|
||||||
|
"Accept-Encoding": []string{"*"},
|
||||||
|
"User-Agent": []string{"test-agent"},
|
||||||
|
"Foo": []string{"bar, baz"},
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.WithHeaders(replacementHeaders)
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: replacementHeaders,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("replacing the header", func() {
|
||||||
|
replacementHeaders := http.Header{
|
||||||
|
"Accept-Encoding": []string{"*"},
|
||||||
|
"User-Agent": []string{"test-agent"},
|
||||||
|
"Foo": []string{"bar, baz"},
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.WithHeaders(replacementHeaders)
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: replacementHeaders,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("then setting a header", func() {
|
||||||
|
header := replacementHeaders.Clone()
|
||||||
|
header.Set("User-Agent", "different-agent")
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = b.SetHeader("User-Agent", "different-agent")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: header,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("if the request has been completed and then modified", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
result := b.Do()
|
||||||
|
Expect(result.Error()).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
b.WithMethod("POST")
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("should not redo the request", func() {
|
||||||
|
assertSuccessfulRequest(getBuilder, testHTTPRequest{
|
||||||
|
Method: "GET",
|
||||||
|
Header: baseHeaders,
|
||||||
|
Body: []byte{},
|
||||||
|
RequestURI: "/json/path",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when the requested page is not found", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = New(serverAddr + "/not-found")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertJSONError(getBuilder, "404 page not found")
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when the requested page is not valid JSON", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
b = New(serverAddr + "/string/path")
|
||||||
|
})
|
||||||
|
|
||||||
|
assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) {
|
||||||
|
Context("Do", func() {
|
||||||
|
var result Result
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
result = builder().Do()
|
||||||
|
Expect(result.Error()).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns a successful status", func() {
|
||||||
|
Expect(result.StatusCode()).To(Equal(http.StatusOK))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("made the expected request", func() {
|
||||||
|
actualRequest := testHTTPRequest{}
|
||||||
|
Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed())
|
||||||
|
|
||||||
|
Expect(actualRequest).To(Equal(expectedRequest))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalInto", func() {
|
||||||
|
var actualRequest testHTTPRequest
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("made the expected request", func() {
|
||||||
|
Expect(actualRequest).To(Equal(expectedRequest))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalJSON", func() {
|
||||||
|
var response *simplejson.Json
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
response, err = builder().Do().UnmarshalJSON()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("made the expected reqest", func() {
|
||||||
|
header := http.Header{}
|
||||||
|
for key, value := range response.Get("Header").MustMap() {
|
||||||
|
vs, ok := value.([]interface{})
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
svs := []string{}
|
||||||
|
for _, v := range vs {
|
||||||
|
sv, ok := v.(string)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
svs = append(svs, sv)
|
||||||
|
}
|
||||||
|
header[key] = svs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other json unmarhsallers base64 decode byte slices automatically
|
||||||
|
body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
actualRequest := testHTTPRequest{
|
||||||
|
Method: response.Get("Method").MustString(),
|
||||||
|
Header: header,
|
||||||
|
Body: body,
|
||||||
|
RequestURI: response.Get("RequestURI").MustString(),
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(actualRequest).To(Equal(expectedRequest))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertRequestError(builder func() Builder, errorMessage string) {
|
||||||
|
Context("Do", func() {
|
||||||
|
It("returns an error", func() {
|
||||||
|
result := builder().Do()
|
||||||
|
Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalInto", func() {
|
||||||
|
It("returns an error", func() {
|
||||||
|
var actualRequest testHTTPRequest
|
||||||
|
err := builder().Do().UnmarshalInto(&actualRequest)
|
||||||
|
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||||
|
|
||||||
|
// Should be empty
|
||||||
|
Expect(actualRequest).To(Equal(testHTTPRequest{}))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalJSON", func() {
|
||||||
|
It("returns an error", func() {
|
||||||
|
resp, err := builder().Do().UnmarshalJSON()
|
||||||
|
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||||
|
Expect(resp).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertJSONError(builder func() Builder, errorMessage string) {
|
||||||
|
Context("Do", func() {
|
||||||
|
It("does not return an error", func() {
|
||||||
|
result := builder().Do()
|
||||||
|
Expect(result.Error()).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalInto", func() {
|
||||||
|
It("returns an error", func() {
|
||||||
|
var actualRequest testHTTPRequest
|
||||||
|
err := builder().Do().UnmarshalInto(&actualRequest)
|
||||||
|
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||||
|
|
||||||
|
// Should be empty
|
||||||
|
Expect(actualRequest).To(Equal(testHTTPRequest{}))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalJSON", func() {
|
||||||
|
It("returns an error", func() {
|
||||||
|
resp, err := builder().Do().UnmarshalJSON()
|
||||||
|
Expect(err).To(MatchError(ContainSubstring(errorMessage)))
|
||||||
|
Expect(resp).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
@ -1,74 +0,0 @@
|
|||||||
package requests
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/bitly/go-simplejson"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Request parses the request body into a simplejson.Json object
|
|
||||||
func Request(req *http.Request) (*simplejson.Json, error) {
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("%s %s %s", req.Method, req.URL, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if body != nil {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("problem reading http request body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("got %d %s", resp.StatusCode, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := simplejson.NewJson(body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error unmarshalling json: %w", err)
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestJSON parses the request body into the given interface
|
|
||||||
func RequestJSON(req *http.Request, v interface{}) error {
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("%s %s %s", req.Method, req.URL, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if body != nil {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error reading body from http response: %w", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return fmt.Errorf("got %d %s", resp.StatusCode, body)
|
|
||||||
}
|
|
||||||
return json.Unmarshal(body, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestUnparsedResponse performs a GET and returns the raw response object
|
|
||||||
func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error performing get request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header = header
|
|
||||||
|
|
||||||
return http.DefaultClient.Do(req)
|
|
||||||
}
|
|
96
pkg/requests/requests_suite_test.go
Normal file
96
pkg/requests/requests_suite_test.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
server *httptest.Server
|
||||||
|
serverAddr string
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequetsSuite(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
log.SetOutput(GinkgoWriter)
|
||||||
|
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Requests Suite")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = BeforeSuite(func() {
|
||||||
|
// Set up a webserver that reflects requests
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle("/json/", &testHTTPUpstream{})
|
||||||
|
mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) {
|
||||||
|
rw.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
server = httptest.NewServer(mux)
|
||||||
|
serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String())
|
||||||
|
})
|
||||||
|
|
||||||
|
var _ = AfterSuite(func() {
|
||||||
|
server.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// testHTTPRequest is a struct used to capture the state of a request made to
|
||||||
|
// the test server
|
||||||
|
type testHTTPRequest struct {
|
||||||
|
Method string
|
||||||
|
Header http.Header
|
||||||
|
Body []byte
|
||||||
|
RequestURI string
|
||||||
|
}
|
||||||
|
|
||||||
|
type testHTTPUpstream struct{}
|
||||||
|
|
||||||
|
func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
request, err := toTestHTTPRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.writeError(rw, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
t.writeError(rw, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
|
rw.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) {
|
||||||
|
rw.WriteHeader(500)
|
||||||
|
if err != nil {
|
||||||
|
rw.Write([]byte(err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) {
|
||||||
|
requestBody := []byte{}
|
||||||
|
if req.Body != http.NoBody {
|
||||||
|
var err error
|
||||||
|
requestBody, err = ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return testHTTPRequest{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return testHTTPRequest{
|
||||||
|
Method: req.Method,
|
||||||
|
Header: req.Header,
|
||||||
|
Body: requestBody,
|
||||||
|
RequestURI: req.RequestURI,
|
||||||
|
}, nil
|
||||||
|
}
|
@ -1,136 +0,0 @@
|
|||||||
package requests
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/bitly/go-simplejson"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server {
|
|
||||||
return httptest.NewServer(http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(responseCode)
|
|
||||||
_, err := w.Write([]byte(payload))
|
|
||||||
require.NoError(t, err)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequest(t *testing.T) {
|
|
||||||
backend := testBackend(t, 200, "{\"foo\": \"bar\"}")
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", backend.URL, nil)
|
|
||||||
response, err := Request(req)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
result, err := response.Get("foo").String()
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
assert.Equal(t, "bar", result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestFailure(t *testing.T) {
|
|
||||||
// Create a backend to generate a test URL, then close it to cause a
|
|
||||||
// connection error.
|
|
||||||
backend := testBackend(t, 200, "{\"foo\": \"bar\"}")
|
|
||||||
backend.Close()
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", backend.URL, nil)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
resp, err := Request(req)
|
|
||||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
|
||||||
assert.NotEqual(t, nil, err)
|
|
||||||
if !strings.Contains(err.Error(), "refused") {
|
|
||||||
t.Error("expected error when a connection fails: ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHttpErrorCode(t *testing.T) {
|
|
||||||
backend := testBackend(t, 404, "{\"foo\": \"bar\"}")
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", backend.URL, nil)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
resp, err := Request(req)
|
|
||||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
|
||||||
assert.NotEqual(t, nil, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJsonParsingError(t *testing.T) {
|
|
||||||
backend := testBackend(t, 200, "not well-formed JSON")
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", backend.URL, nil)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
resp, err := Request(req)
|
|
||||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
|
||||||
assert.NotEqual(t, nil, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parsing a URL practically never fails, so we won't cover that test case.
|
|
||||||
func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
|
|
||||||
backend := httptest.NewServer(http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
token := r.FormValue("access_token")
|
|
||||||
if r.URL.Path == "/" && token == "my_token" {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
_, err := w.Write([]byte("some payload"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(403)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
response, err := RequestUnparsedResponse(
|
|
||||||
context.Background(), backend.URL+"?access_token=my_token", nil)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
defer response.Body.Close()
|
|
||||||
|
|
||||||
assert.Equal(t, 200, response.StatusCode)
|
|
||||||
body, err := ioutil.ReadAll(response.Body)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
assert.Equal(t, "some payload", string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) {
|
|
||||||
backend := testBackend(t, 200, "some payload")
|
|
||||||
// Close the backend now to force a request failure.
|
|
||||||
backend.Close()
|
|
||||||
|
|
||||||
response, err := RequestUnparsedResponse(
|
|
||||||
context.Background(), backend.URL+"?access_token=my_token", nil)
|
|
||||||
assert.NotEqual(t, nil, err)
|
|
||||||
assert.Equal(t, (*http.Response)(nil), response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
|
|
||||||
backend := httptest.NewServer(http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
_, err := w.Write([]byte("some payload"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(403)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer backend.Close()
|
|
||||||
|
|
||||||
headers := make(http.Header)
|
|
||||||
headers.Set("Auth", "my_token")
|
|
||||||
response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
defer response.Body.Close()
|
|
||||||
|
|
||||||
assert.Equal(t, 200, response.StatusCode)
|
|
||||||
body, err := ioutil.ReadAll(response.Body)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "some payload", string(body))
|
|
||||||
}
|
|
98
pkg/requests/result.go
Normal file
98
pkg/requests/result.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitly/go-simplejson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Result is the result of a request created by a Builder
|
||||||
|
type Result interface {
|
||||||
|
Error() error
|
||||||
|
StatusCode() int
|
||||||
|
Headers() http.Header
|
||||||
|
Body() []byte
|
||||||
|
UnmarshalInto(interface{}) error
|
||||||
|
UnmarshalJSON() (*simplejson.Json, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
err error
|
||||||
|
response *http.Response
|
||||||
|
body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns an error from the result if present
|
||||||
|
func (r *result) Error() error {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode returns the response's status code
|
||||||
|
func (r *result) StatusCode() int {
|
||||||
|
if r.response != nil {
|
||||||
|
return r.response.StatusCode
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers returns the response's headers
|
||||||
|
func (r *result) Headers() http.Header {
|
||||||
|
if r.response != nil {
|
||||||
|
return r.response.Header
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Body returns the response's body
|
||||||
|
func (r *result) Body() []byte {
|
||||||
|
return r.body
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalInto attempts to unmarshal the response into the the given interface.
|
||||||
|
// The response body is assumed to be JSON.
|
||||||
|
// The response must have a 200 status otherwise an error will be returned.
|
||||||
|
func (r *result) UnmarshalInto(into interface{}) error {
|
||||||
|
body, err := r.getBodyForUnmarshal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, into); err != nil {
|
||||||
|
return fmt.Errorf("error unmarshalling body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON performs the request and attempts to unmarshal the response into a
|
||||||
|
// simplejson.Json. The response body is assume to be JSON.
|
||||||
|
// The response must have a 200 status otherwise an error will be returned.
|
||||||
|
func (r *result) UnmarshalJSON() (*simplejson.Json, error) {
|
||||||
|
body, err := r.getBodyForUnmarshal()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := simplejson.NewJson(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error reading json: %v", err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBodyForUnmarshal returns the body if there wasn't an error and the status
|
||||||
|
// code was 200.
|
||||||
|
func (r *result) getBodyForUnmarshal() ([]byte, error) {
|
||||||
|
if r.Error() != nil {
|
||||||
|
return nil, r.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only unmarshal body if the response was successful
|
||||||
|
if r.StatusCode() != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body())
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Body(), nil
|
||||||
|
}
|
326
pkg/requests/result_test.go
Normal file
326
pkg/requests/result_test.go
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Result suite", func() {
|
||||||
|
Context("with a result", func() {
|
||||||
|
type resultTableInput struct {
|
||||||
|
result Result
|
||||||
|
expectedError error
|
||||||
|
expectedStatusCode int
|
||||||
|
expectedHeaders http.Header
|
||||||
|
expectedBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("accessors should return expected results",
|
||||||
|
func(in resultTableInput) {
|
||||||
|
if in.expectedError != nil {
|
||||||
|
Expect(in.result.Error()).To(MatchError(in.expectedError))
|
||||||
|
} else {
|
||||||
|
Expect(in.result.Error()).To(BeNil())
|
||||||
|
}
|
||||||
|
|
||||||
|
Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode))
|
||||||
|
Expect(in.result.Headers()).To(Equal(in.expectedHeaders))
|
||||||
|
Expect(in.result.Body()).To(Equal(in.expectedBody))
|
||||||
|
},
|
||||||
|
Entry("with an empty result", resultTableInput{
|
||||||
|
result: &result{},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedStatusCode: 0,
|
||||||
|
expectedHeaders: nil,
|
||||||
|
expectedBody: nil,
|
||||||
|
}),
|
||||||
|
Entry("with an error", resultTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: errors.New("error"),
|
||||||
|
},
|
||||||
|
expectedError: errors.New("error"),
|
||||||
|
expectedStatusCode: 0,
|
||||||
|
expectedHeaders: nil,
|
||||||
|
expectedBody: nil,
|
||||||
|
}),
|
||||||
|
Entry("with a response with no headers", resultTableInput{
|
||||||
|
result: &result{
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedStatusCode: http.StatusTeapot,
|
||||||
|
expectedHeaders: nil,
|
||||||
|
expectedBody: nil,
|
||||||
|
}),
|
||||||
|
Entry("with a response with no status code", resultTableInput{
|
||||||
|
result: &result{
|
||||||
|
response: &http.Response{
|
||||||
|
Header: http.Header{
|
||||||
|
"foo": []string{"bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedStatusCode: 0,
|
||||||
|
expectedHeaders: http.Header{
|
||||||
|
"foo": []string{"bar"},
|
||||||
|
},
|
||||||
|
expectedBody: nil,
|
||||||
|
}),
|
||||||
|
Entry("with a response with a body", resultTableInput{
|
||||||
|
result: &result{
|
||||||
|
body: []byte("some body"),
|
||||||
|
},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedStatusCode: 0,
|
||||||
|
expectedHeaders: nil,
|
||||||
|
expectedBody: []byte("some body"),
|
||||||
|
}),
|
||||||
|
Entry("with all fields", resultTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: errors.New("some error"),
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusFound,
|
||||||
|
Header: http.Header{
|
||||||
|
"header": []string{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
body: []byte("a body"),
|
||||||
|
},
|
||||||
|
expectedError: errors.New("some error"),
|
||||||
|
expectedStatusCode: http.StatusFound,
|
||||||
|
expectedHeaders: http.Header{
|
||||||
|
"header": []string{"value"},
|
||||||
|
},
|
||||||
|
expectedBody: []byte("a body"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalInto", func() {
|
||||||
|
type testStruct struct {
|
||||||
|
A string `json:"a"`
|
||||||
|
B int `json:"b"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type unmarshalIntoTableInput struct {
|
||||||
|
result Result
|
||||||
|
expectedErr error
|
||||||
|
expectedOutput *testStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("with a result",
|
||||||
|
func(in unmarshalIntoTableInput) {
|
||||||
|
input := &testStruct{}
|
||||||
|
err := in.result.UnmarshalInto(input)
|
||||||
|
if in.expectedErr != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
|
} else {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
Expect(input).To(Equal(in.expectedOutput))
|
||||||
|
},
|
||||||
|
Entry("with an error", unmarshalIntoTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: errors.New("got an error"),
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("{\"a\": \"foo\"}"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("got an error"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
Entry("with a 409 status code", unmarshalIntoTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusConflict,
|
||||||
|
},
|
||||||
|
body: []byte("{\"a\": \"foo\"}"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
Entry("when the response has a valid json response", unmarshalIntoTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("{\"a\": \"foo\", \"b\": 1}"),
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectedOutput: &testStruct{A: "foo", B: 1},
|
||||||
|
}),
|
||||||
|
Entry("when the response body is empty", unmarshalIntoTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte(""),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error unmarshalling body: unexpected end of JSON input"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
Entry("when the response body is not json", unmarshalIntoTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("not json"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("UnmarshalJSON", func() {
|
||||||
|
type testStruct struct {
|
||||||
|
A string `json:"a"`
|
||||||
|
B int `json:"b"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type unmarshalJSONTableInput struct {
|
||||||
|
result Result
|
||||||
|
expectedErr error
|
||||||
|
expectedOutput *testStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("with a result",
|
||||||
|
func(in unmarshalJSONTableInput) {
|
||||||
|
j, err := in.result.UnmarshalJSON()
|
||||||
|
if in.expectedErr != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
|
Expect(j).To(BeNil())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No error so j should not be nil
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
input := &testStruct{
|
||||||
|
A: j.Get("a").MustString(),
|
||||||
|
B: j.Get("b").MustInt(),
|
||||||
|
}
|
||||||
|
Expect(input).To(Equal(in.expectedOutput))
|
||||||
|
},
|
||||||
|
Entry("with an error", unmarshalJSONTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: errors.New("got an error"),
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("{\"a\": \"foo\"}"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("got an error"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
Entry("with a 409 status code", unmarshalJSONTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusConflict,
|
||||||
|
},
|
||||||
|
body: []byte("{\"a\": \"foo\"}"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("unexpected status \"409\": {\"a\": \"foo\"}"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
Entry("when the response has a valid json response", unmarshalJSONTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("{\"a\": \"foo\", \"b\": 1}"),
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectedOutput: &testStruct{A: "foo", B: 1},
|
||||||
|
}),
|
||||||
|
Entry("when the response body is empty", unmarshalJSONTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte(""),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error reading json: EOF"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
Entry("when the response body is not json", unmarshalJSONTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("not json"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"),
|
||||||
|
expectedOutput: &testStruct{},
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("getBodyForUnmarshal", func() {
|
||||||
|
type getBodyForUnmarshalTableInput struct {
|
||||||
|
result *result
|
||||||
|
expectedErr error
|
||||||
|
expectedBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) {
|
||||||
|
body, err := in.result.getBodyForUnmarshal()
|
||||||
|
if in.expectedErr != nil {
|
||||||
|
Expect(err).To(MatchError(in.expectedErr))
|
||||||
|
} else {
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
Expect(body).To(Equal(in.expectedBody))
|
||||||
|
},
|
||||||
|
Entry("when the result has an error", getBodyForUnmarshalTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: errors.New("got an error"),
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("body"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("got an error"),
|
||||||
|
expectedBody: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusConflict,
|
||||||
|
},
|
||||||
|
body: []byte("body"),
|
||||||
|
},
|
||||||
|
expectedErr: errors.New("unexpected status \"409\": body"),
|
||||||
|
expectedBody: nil,
|
||||||
|
}),
|
||||||
|
Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{
|
||||||
|
result: &result{
|
||||||
|
err: nil,
|
||||||
|
response: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
body: []byte("body"),
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
expectedBody: []byte("body"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
@ -83,9 +83,14 @@ func Validate(o *options.Options) error {
|
|||||||
|
|
||||||
logger.Printf("Performing OIDC Discovery...")
|
logger.Printf("Performing OIDC Discovery...")
|
||||||
|
|
||||||
if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil {
|
requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration"
|
||||||
if body, err := requests.Request(req); err == nil {
|
body, err := requests.New(requestURL).
|
||||||
|
WithContext(ctx).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
||||||
|
} else {
|
||||||
// Prefer manually configured URLs. It's a bit unclear
|
// Prefer manually configured URLs. It's a bit unclear
|
||||||
// why you'd be doing discovery and also providing the URLs
|
// why you'd be doing discovery and also providing the URLs
|
||||||
// explicitly though...
|
// explicitly though...
|
||||||
@ -106,11 +111,6 @@ func Validate(o *options.Options) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
o.SkipOIDCDiscovery = true
|
o.SkipOIDCDiscovery = true
|
||||||
} else {
|
|
||||||
logger.Printf("error: failed to discover OIDC configuration: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Printf("error: failed parsing OIDC discovery URL: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,10 +385,10 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Try as JWKS URI
|
// Try as JWKS URI
|
||||||
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
|
jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json"
|
||||||
_, err := http.NewRequest("GET", jwksURI, nil)
|
if err := requests.New(jwksURI).Do().Error(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
|
verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config)
|
||||||
} else {
|
} else {
|
||||||
verifier = provider.Verifier(config)
|
verifier = provider.Verifier(config)
|
||||||
|
@ -3,10 +3,8 @@ package providers
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
@ -91,39 +89,22 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
params.Add("resource", p.ProtectedResource.String())
|
params.Add("resource", p.ProtectedResource.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
var req *http.Request
|
|
||||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
var resp *http.Response
|
|
||||||
resp, err = http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var jsonResponse struct {
|
var jsonResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
ExpiresOn int64 `json:"expires_on,string"`
|
ExpiresOn int64 `json:"expires_on,string"`
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &jsonResponse)
|
|
||||||
|
err = requests.New(p.RedeemURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithMethod("POST").
|
||||||
|
WithBody(bytes.NewBufferString(params.Encode())).
|
||||||
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&jsonResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
@ -169,26 +150,22 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session
|
|||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req.Header = getAzureHeader(s.AccessToken)
|
|
||||||
|
|
||||||
json, err := requests.Request(req)
|
|
||||||
|
|
||||||
|
json, err := requests.New(p.ProfileURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getAzureHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err = getEmailFromJSON(json)
|
email, err = getEmailFromJSON(json)
|
||||||
|
|
||||||
if err == nil && email != "" {
|
if err == nil && email != "" {
|
||||||
return email, err
|
return email, err
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err = json.Get("userPrincipalName").String()
|
email, err = json.Get("userPrincipalName").String()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed making request %s", err)
|
logger.Printf("failed making request %s", err)
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -2,7 +2,6 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -85,15 +84,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
|||||||
FullName string `json:"full_name"`
|
FullName string `json:"full_name"`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
|
||||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken
|
||||||
|
err := requests.New(requestURL).
|
||||||
|
WithContext(ctx).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&emails)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
logger.Printf("failed making request: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
err = requests.RequestJSON(req, &emails)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("failed making request %s", err)
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -101,15 +99,15 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
|||||||
teamURL := &url.URL{}
|
teamURL := &url.URL{}
|
||||||
*teamURL = *p.ValidateURL
|
*teamURL = *p.ValidateURL
|
||||||
teamURL.Path = "/2.0/teams"
|
teamURL.Path = "/2.0/teams"
|
||||||
req, err = http.NewRequestWithContext(ctx, "GET",
|
|
||||||
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
|
requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken
|
||||||
|
|
||||||
|
err := requests.New(requestURL).
|
||||||
|
WithContext(ctx).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&teams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
logger.Printf("failed requesting teams membership: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
err = requests.RequestJSON(req, &teams)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("failed requesting teams membership %s", err)
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
var found = false
|
var found = false
|
||||||
@ -129,20 +127,20 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses
|
|||||||
repositoriesURL := &url.URL{}
|
repositoriesURL := &url.URL{}
|
||||||
*repositoriesURL = *p.ValidateURL
|
*repositoriesURL = *p.ValidateURL
|
||||||
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
|
repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0]
|
||||||
req, err = http.NewRequestWithContext(ctx, "GET",
|
|
||||||
repositoriesURL.String()+"?role=contributor"+
|
requestURL := repositoriesURL.String() + "?role=contributor" +
|
||||||
"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+
|
"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") +
|
||||||
"&access_token="+s.AccessToken,
|
"&access_token=" + s.AccessToken
|
||||||
nil)
|
|
||||||
|
err := requests.New(requestURL).
|
||||||
|
WithContext(ctx).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&repositories)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
logger.Printf("failed checking repository access: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
err = requests.RequestJSON(req, &repositories)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("failed checking repository access %s", err)
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var found = false
|
var found = false
|
||||||
for _, repository := range repositories.Values {
|
for _, repository := range repositories.Values {
|
||||||
if p.Repository == repository.FullName {
|
if p.Repository == repository.FullName {
|
||||||
|
@ -60,13 +60,12 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.
|
|||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req.Header = getDigitalOceanHeader(s.AccessToken)
|
|
||||||
|
|
||||||
json, err := requests.Request(req)
|
json, err := requests.New(p.ProfileURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getDigitalOceanHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -62,20 +62,22 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
|||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req.Header = getFacebookHeader(s.AccessToken)
|
|
||||||
|
|
||||||
type result struct {
|
type result struct {
|
||||||
Email string
|
Email string
|
||||||
}
|
}
|
||||||
var r result
|
var r result
|
||||||
err = requests.RequestJSON(req, &r)
|
|
||||||
|
requestURL := p.ProfileURL.String() + "?fields=name,email"
|
||||||
|
err := requests.New(requestURL).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getFacebookHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Email == "" {
|
if r.Email == "" {
|
||||||
return "", errors.New("no email")
|
return "", errors.New("no email")
|
||||||
}
|
}
|
||||||
|
@ -2,10 +2,8 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
@ -15,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitHubProvider represents an GitHub based Identity Provider
|
// GitHubProvider represents an GitHub based Identity Provider
|
||||||
@ -111,27 +110,17 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool,
|
|||||||
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
|
Path: path.Join(p.ValidateURL.Path, "/user/orgs"),
|
||||||
RawQuery: params.Encode(),
|
RawQuery: params.Encode(),
|
||||||
}
|
}
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
|
||||||
req.Header = getGitHubHeader(accessToken)
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return false, fmt.Errorf(
|
|
||||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
||||||
}
|
|
||||||
|
|
||||||
var op orgsPage
|
var op orgsPage
|
||||||
if err := json.Unmarshal(body, &op); err != nil {
|
err := requests.New(endpoint.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getGitHubHeader(accessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&op)
|
||||||
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(op) == 0 {
|
if len(op) == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -187,11 +176,15 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
|||||||
RawQuery: params.Encode(),
|
RawQuery: params.Encode(),
|
||||||
}
|
}
|
||||||
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
// bodyclose cannot detect that the body is being closed later in requests.Into,
|
||||||
req.Header = getGitHubHeader(accessToken)
|
// so have to skip the linting for the next line.
|
||||||
resp, err := http.DefaultClient.Do(req)
|
// nolint:bodyclose
|
||||||
if err != nil {
|
result := requests.New(endpoint.String()).
|
||||||
return false, err
|
WithContext(ctx).
|
||||||
|
WithHeaders(getGitHubHeader(accessToken)).
|
||||||
|
Do()
|
||||||
|
if result.Error() != nil {
|
||||||
|
return false, result.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
if last == 0 {
|
if last == 0 {
|
||||||
@ -207,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
|||||||
// link header at last page (doesn't exist last info)
|
// link header at last page (doesn't exist last info)
|
||||||
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
|
// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
|
||||||
|
|
||||||
link := resp.Header.Get("Link")
|
link := result.Headers().Get("Link")
|
||||||
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
|
rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`)
|
||||||
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
|
i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1"))
|
||||||
|
|
||||||
@ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
resp.Body.Close()
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return false, fmt.Errorf(
|
|
||||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
||||||
}
|
|
||||||
|
|
||||||
var tp teamsPage
|
var tp teamsPage
|
||||||
if err := json.Unmarshal(body, &tp); err != nil {
|
if err := result.UnmarshalInto(&tp); err != nil {
|
||||||
return false, fmt.Errorf("%s unmarshaling %s", err, body)
|
return false, err
|
||||||
}
|
}
|
||||||
if len(tp) == 0 {
|
if len(tp) == 0 {
|
||||||
break
|
break
|
||||||
@ -297,25 +278,13 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool,
|
|||||||
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
|
Path: path.Join(p.ValidateURL.Path, "/repo/", p.Repo),
|
||||||
}
|
}
|
||||||
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
|
||||||
req.Header = getGitHubHeader(accessToken)
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return false, fmt.Errorf(
|
|
||||||
"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
||||||
}
|
|
||||||
|
|
||||||
var repo repository
|
var repo repository
|
||||||
if err := json.Unmarshal(body, &repo); err != nil {
|
err := requests.New(endpoint.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getGitHubHeader(accessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&repo)
|
||||||
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -337,26 +306,15 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool,
|
|||||||
Host: p.ValidateURL.Host,
|
Host: p.ValidateURL.Host,
|
||||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||||
}
|
}
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
|
||||||
req.Header = getGitHubHeader(accessToken)
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
err := requests.New(endpoint.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getGitHubHeader(accessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return false, fmt.Errorf("got %d from %q %s",
|
|
||||||
resp.StatusCode, stripToken(endpoint.String()), body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &user); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.isVerifiedUser(user.Login) {
|
if p.isVerifiedUser(user.Login) {
|
||||||
return true, nil
|
return true, nil
|
||||||
@ -372,24 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
|
|||||||
Host: p.ValidateURL.Host,
|
Host: p.ValidateURL.Host,
|
||||||
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
|
Path: path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username),
|
||||||
}
|
}
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
result := requests.New(endpoint.String()).
|
||||||
req.Header = getGitHubHeader(accessToken)
|
WithContext(ctx).
|
||||||
resp, err := http.DefaultClient.Do(req)
|
WithHeaders(getGitHubHeader(accessToken)).
|
||||||
if err != nil {
|
Do()
|
||||||
return false, err
|
if result.Error() != nil {
|
||||||
}
|
return false, result.Error()
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 204 {
|
if result.StatusCode() != 204 {
|
||||||
return false, fmt.Errorf("got %d from %q %s",
|
return false, fmt.Errorf("got %d from %q %s",
|
||||||
resp.StatusCode, endpoint.String(), body)
|
result.StatusCode(), endpoint.String(), result.Body())
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body())
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@ -440,28 +394,14 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
|
|||||||
Host: p.ValidateURL.Host,
|
Host: p.ValidateURL.Host,
|
||||||
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
|
Path: path.Join(p.ValidateURL.Path, "/user/emails"),
|
||||||
}
|
}
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
err := requests.New(endpoint.String()).
|
||||||
req.Header = getGitHubHeader(s.AccessToken)
|
WithContext(ctx).
|
||||||
resp, err := http.DefaultClient.Do(req)
|
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&emails)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return "", fmt.Errorf("got %d from %q %s",
|
|
||||||
resp.StatusCode, endpoint.String(), body)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &emails); err != nil {
|
|
||||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
returnEmail := ""
|
returnEmail := ""
|
||||||
for _, email := range emails {
|
for _, email := range emails {
|
||||||
@ -489,34 +429,15 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
|
|||||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil)
|
err := requests.New(endpoint.String()).
|
||||||
if err != nil {
|
WithContext(ctx).
|
||||||
return "", fmt.Errorf("could not create new GET request: %v", err)
|
WithHeaders(getGitHubHeader(s.AccessToken)).
|
||||||
}
|
Do().
|
||||||
|
UnmarshalInto(&user)
|
||||||
req.Header = getGitHubHeader(s.AccessToken)
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return "", fmt.Errorf("got %d from %q %s",
|
|
||||||
resp.StatusCode, endpoint.String(), body)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &user); err != nil {
|
|
||||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we have the username we can check collaborator status
|
// Now that we have the username we can check collaborator status
|
||||||
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
|
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
|
||||||
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {
|
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {
|
||||||
|
@ -2,15 +2,13 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
oidc "github.com/coreos/go-oidc"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -131,31 +129,14 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta
|
|||||||
userInfoURL := *p.LoginURL
|
userInfoURL := *p.LoginURL
|
||||||
userInfoURL.Path = "/oauth/userinfo"
|
userInfoURL.Path = "/oauth/userinfo"
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create user info request: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to perform user info request: %v", err)
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read user info response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
var userInfo gitlabUserInfo
|
var userInfo gitlabUserInfo
|
||||||
err = json.Unmarshal(body, &userInfo)
|
err := requests.New(userInfoURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&userInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse user info: %v", err)
|
return nil, fmt.Errorf("error getting user info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &userInfo, nil
|
return &userInfo, nil
|
||||||
|
@ -9,13 +9,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
"google.golang.org/api/googleapi"
|
"google.golang.org/api/googleapi"
|
||||||
@ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
params.Add("client_secret", clientSecret)
|
params.Add("client_secret", clientSecret)
|
||||||
params.Add("code", code)
|
params.Add("code", code)
|
||||||
params.Add("grant_type", "authorization_code")
|
params.Add("grant_type", "authorization_code")
|
||||||
var req *http.Request
|
|
||||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var jsonResponse struct {
|
var jsonResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
@ -145,10 +123,18 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &jsonResponse)
|
|
||||||
|
err = requests.New(p.RedeemURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithMethod("POST").
|
||||||
|
WithBody(bytes.NewBufferString(params.Encode())).
|
||||||
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&jsonResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
c, err := claimsFromIDToken(jsonResponse.IDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -283,38 +269,24 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st
|
|||||||
params.Add("client_secret", clientSecret)
|
params.Add("client_secret", clientSecret)
|
||||||
params.Add("refresh_token", refreshToken)
|
params.Add("refresh_token", refreshToken)
|
||||||
params.Add("grant_type", "refresh_token")
|
params.Add("grant_type", "refresh_token")
|
||||||
var req *http.Request
|
|
||||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var data struct {
|
var data struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &data)
|
|
||||||
|
err = requests.New(p.RedeemURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithMethod("POST").
|
||||||
|
WithBody(bytes.NewBufferString(params.Encode())).
|
||||||
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return "", "", 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
token = data.AccessToken
|
token = data.AccessToken
|
||||||
idToken = data.IDToken
|
idToken = data.IDToken
|
||||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||||
|
@ -2,7 +2,6 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
@ -56,20 +55,22 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h
|
|||||||
params := url.Values{"access_token": {accessToken}}
|
params := url.Values{"access_token": {accessToken}}
|
||||||
endpoint = endpoint + "?" + params.Encode()
|
endpoint = endpoint + "?" + params.Encode()
|
||||||
}
|
}
|
||||||
resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header)
|
|
||||||
if err != nil {
|
result := requests.New(endpoint).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(header).
|
||||||
|
Do()
|
||||||
|
if result.Error() != nil {
|
||||||
logger.Printf("GET %s", stripToken(endpoint))
|
logger.Printf("GET %s", stripToken(endpoint))
|
||||||
logger.Printf("token validation request failed: %s", err)
|
logger.Printf("token validation request failed: %s", result.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := ioutil.ReadAll(resp.Body)
|
logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body())
|
||||||
resp.Body.Close()
|
|
||||||
logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body)
|
|
||||||
|
|
||||||
if resp.StatusCode == 200 {
|
if result.StatusCode() == 200 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
|
logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
@ -51,14 +50,11 @@ func (p *KeycloakProvider) SetGroup(group string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||||
|
json, err := requests.New(p.ValidateURL.String()).
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil)
|
WithContext(ctx).
|
||||||
req.Header.Set("Authorization", "Bearer "+s.AccessToken)
|
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||||
if err != nil {
|
Do().
|
||||||
logger.Printf("failed building request %s", err)
|
UnmarshalJSON()
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
json, err := requests.Request(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed making request %s", err)
|
logger.Printf("failed making request %s", err)
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -58,13 +58,13 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess
|
|||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req.Header = getLinkedInHeader(s.AccessToken)
|
|
||||||
|
|
||||||
json, err := requests.Request(req)
|
requestURL := p.ProfileURL.String() + "?format=json"
|
||||||
|
json, err := requests.New(requestURL).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(getLinkedInHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,51 +129,34 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) {
|
func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) {
|
||||||
// query the user info endpoint for user attributes
|
|
||||||
var req *http.Request
|
|
||||||
req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse the user attributes from the data we got and make sure that
|
// parse the user attributes from the data we got and make sure that
|
||||||
// the email address has been validated.
|
// the email address has been validated.
|
||||||
var emailData struct {
|
var emailData struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &emailData)
|
|
||||||
|
// query the user info endpoint for user attributes
|
||||||
|
err := requests.New(userInfoEndpoint).
|
||||||
|
WithContext(ctx).
|
||||||
|
SetHeader("Authorization", "Bearer "+accessToken).
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&emailData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return "", err
|
||||||
}
|
}
|
||||||
if emailData.Email == "" {
|
|
||||||
err = fmt.Errorf("missing email")
|
email := emailData.Email
|
||||||
return
|
if email == "" {
|
||||||
|
return "", fmt.Errorf("missing email")
|
||||||
}
|
}
|
||||||
email = emailData.Email
|
|
||||||
if !emailData.EmailVerified {
|
if !emailData.EmailVerified {
|
||||||
err = fmt.Errorf("email %s not listed as verified", email)
|
return "", fmt.Errorf("email %s not listed as verified", email)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
@ -201,30 +185,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
params.Add("code", code)
|
params.Add("code", code)
|
||||||
params.Add("grant_type", "authorization_code")
|
params.Add("grant_type", "authorization_code")
|
||||||
|
|
||||||
var req *http.Request
|
|
||||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
var resp *http.Response
|
|
||||||
resp, err = http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the token from the body that we got from the token endpoint.
|
// Get the token from the body that we got from the token endpoint.
|
||||||
var jsonResponse struct {
|
var jsonResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
@ -232,9 +192,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
|
|||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &jsonResponse)
|
err = requests.New(p.RedeemURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithMethod("POST").
|
||||||
|
WithBody(bytes.NewBufferString(params.Encode())).
|
||||||
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
|
Do().
|
||||||
|
UnmarshalInto(&jsonResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// check nonce here
|
// check nonce here
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,18 +30,15 @@ func getNextcloudHeader(accessToken string) http.Header {
|
|||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
json, err := requests.New(p.ValidateURL.String()).
|
||||||
p.ValidateURL.String(), nil)
|
WithContext(ctx).
|
||||||
|
WithHeaders(getNextcloudHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
return "", fmt.Errorf("error making request: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
req.Header = getNextcloudHeader(s.AccessToken)
|
|
||||||
json, err := requests.Request(req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("failed making request %s", err)
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err := json.Get("ocs").Get("data").Get("email").String()
|
email, err := json.Get("ocs").Get("data").Get("email").String()
|
||||||
return email, err
|
return email, err
|
||||||
}
|
}
|
||||||
|
@ -256,13 +256,11 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.
|
|||||||
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
||||||
// contents at the profileURL contains the email.
|
// contents at the profileURL contains the email.
|
||||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil)
|
respJSON, err := requests.New(profileURL).
|
||||||
if err != nil {
|
WithContext(ctx).
|
||||||
return nil, err
|
WithHeaders(getOIDCHeader(accessToken)).
|
||||||
}
|
Do().
|
||||||
req.Header = getOIDCHeader(accessToken)
|
UnmarshalJSON()
|
||||||
|
|
||||||
respJSON, err := requests.Request(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -3,17 +3,15 @@ package providers
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/requests"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ Provider = (*ProviderData)(nil)
|
var _ Provider = (*ProviderData)(nil)
|
||||||
@ -39,35 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
params.Add("resource", p.ProtectedResource.String())
|
params.Add("resource", p.ProtectedResource.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
var req *http.Request
|
result := requests.New(p.RedeemURL.String()).
|
||||||
req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
WithContext(ctx).
|
||||||
if err != nil {
|
WithMethod("POST").
|
||||||
return
|
WithBody(bytes.NewBufferString(params.Encode())).
|
||||||
}
|
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
Do()
|
||||||
|
if result.Error() != nil {
|
||||||
var resp *http.Response
|
return nil, result.Error()
|
||||||
resp, err = http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var body []byte
|
|
||||||
body, err = ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// blindly try json and x-www-form-urlencoded
|
// blindly try json and x-www-form-urlencoded
|
||||||
var jsonResponse struct {
|
var jsonResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &jsonResponse)
|
err = result.UnmarshalInto(&jsonResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s = &sessions.SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
@ -76,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var v url.Values
|
var v url.Values
|
||||||
v, err = url.ParseQuery(string(body))
|
v, err = url.ParseQuery(string(result.Body()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -84,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||||||
created := time.Now()
|
created := time.Now()
|
||||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("no access token found %s", body)
|
err = fmt.Errorf("no access token found %s", result.Body())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user