mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
81 lines
2.6 KiB
Go
81 lines
2.6 KiB
Go
package upstream
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/justinas/alice"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter"
|
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
)
|
|
|
|
// newRewritePath creates a new middleware that will rewrite the request URI
|
|
// path before handing the request to the next server.
|
|
func newRewritePath(rewriteRegExp *regexp.Regexp, rewriteTarget string, writer pagewriter.Writer) alice.Constructor {
|
|
return func(next http.Handler) http.Handler {
|
|
return rewritePath(rewriteRegExp, rewriteTarget, writer, next)
|
|
}
|
|
}
|
|
|
|
// rewritePath uses the regexp to rewrite the request URI based on the provided
|
|
// rewriteTarget.
|
|
func rewritePath(rewriteRegExp *regexp.Regexp, rewriteTarget string, writer pagewriter.Writer, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
reqURL, err := url.ParseRequestURI(req.RequestURI)
|
|
if err != nil {
|
|
logger.Errorf("could not parse request URI: %v", err)
|
|
writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{
|
|
Status: http.StatusInternalServerError,
|
|
RequestID: middleware.GetRequestScope(req).RequestID,
|
|
AppError: fmt.Sprintf("Could not parse request URI: %v", err),
|
|
})
|
|
return
|
|
}
|
|
|
|
// Use the regex to rewrite the request path before proxying to the upstream.
|
|
newURI := rewriteRegExp.ReplaceAllString(reqURL.Path, rewriteTarget)
|
|
reqURL.Path, reqURL.RawQuery, err = splitPathAndQuery(reqURL.Query(), newURI)
|
|
if err != nil {
|
|
logger.Errorf("could not parse rewrite URI: %v", err)
|
|
writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{
|
|
Status: http.StatusInternalServerError,
|
|
RequestID: middleware.GetRequestScope(req).RequestID,
|
|
AppError: fmt.Sprintf("Could not parse rewrite URI: %v", err),
|
|
})
|
|
return
|
|
}
|
|
|
|
req.RequestURI = reqURL.String()
|
|
next.ServeHTTP(rw, req)
|
|
})
|
|
}
|
|
|
|
// splitPathAndQuery splits the rewritten URI raw at its first "?" into a URL
// path and a raw (encoded) query string.
//
// Any query values found in raw are added to originalQuery, which is modified
// in place. The merged set is then encoded with url.Values.Encode, which sorts
// by key. For a repeated key the original values come before the rewritten
// ones, and each is emitted separately, e.g. ?foo=bar&foo=baz.
//
// If raw contains no "?", raw is returned as the path together with the
// encoded originalQuery. If the query portion of raw is malformed, empty
// strings and the parse error are returned.
func splitPathAndQuery(originalQuery url.Values, raw string) (string, string, error) {
	s := strings.SplitN(raw, "?", 2)
	if len(s) == 1 {
		return s[0], originalQuery.Encode(), nil
	}

	queryValues, err := url.ParseQuery(s[1])
	if err != nil {
		// Propagate the failure so the caller can reject the request. The old
		// code returned a nil error here, which sent the request upstream with
		// an empty path and query.
		return "", "", err
	}

	for key, values := range queryValues {
		for _, value := range values {
			originalQuery.Add(key, value)
		}
	}

	return s[0], originalQuery.Encode(), nil
}
|