2019-05-16 10:39:25 -04:00
|
|
|
package web
|
|
|
|
|
|
|
|
import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"sort"
	"strings"

	"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
	"github.com/gorilla/schema"
	"github.com/pkg/errors"
	"github.com/xwb1989/sqlparser"
	"github.com/xwb1989/sqlparser/dependency/querypb"
)
|
|
|
|
|
2019-07-12 11:41:41 -08:00
|
|
|
// HTTP header names used by this package.
const (
	// General request/response headers.
	HeaderAccept  = "Accept"
	HeaderOrigin  = "Origin"
	HeaderServer  = "Server"
	HeaderUpgrade = "Upgrade"

	// Headers consulted by RequestScheme and RequestRealIP when deriving the
	// scheme and client address of a request.
	HeaderXForwardedFor      = "X-Forwarded-For"
	HeaderXForwardedProto    = "X-Forwarded-Proto"
	HeaderXForwardedProtocol = "X-Forwarded-Protocol"
	HeaderXForwardedSsl      = "X-Forwarded-Ssl"
	HeaderXRealIP            = "X-Real-IP"
	HeaderXUrlScheme         = "X-Url-Scheme"

	// Remaining non-standard "X-" headers.
	HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
	HeaderXRequestID          = "X-Request-ID"
	HeaderXRequestedWith      = "X-Requested-With"
)
|
|
|
|
|
2019-05-16 10:39:25 -04:00
|
|
|
// Decode reads the body of an HTTP request looking for a JSON document. The
|
|
|
|
// body is decoded into the provided value.
|
|
|
|
//
|
|
|
|
// If the provided value is a struct then it is checked for validation tags.
|
2019-07-31 13:47:30 -08:00
|
|
|
func Decode(ctx context.Context, r *http.Request, val interface{}) error {
|
2019-06-24 17:36:42 -08:00
|
|
|
|
2019-06-26 20:21:00 -08:00
|
|
|
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch || r.Method == http.MethodDelete {
|
2019-06-24 17:36:42 -08:00
|
|
|
decoder := json.NewDecoder(r.Body)
|
|
|
|
decoder.DisallowUnknownFields()
|
|
|
|
if err := decoder.Decode(val); err != nil {
|
2019-07-31 13:47:30 -08:00
|
|
|
return weberror.NewErrorMessage(ctx, err, http.StatusBadRequest, "decode request body failed")
|
2019-06-24 17:36:42 -08:00
|
|
|
}
|
|
|
|
} else {
|
|
|
|
decoder := schema.NewDecoder()
|
|
|
|
if err := decoder.Decode(val, r.URL.Query()); err != nil {
|
2019-06-26 20:21:00 -08:00
|
|
|
err = errors.Wrap(err, "decode request query failed")
|
2019-07-31 13:47:30 -08:00
|
|
|
return weberror.NewErrorMessage(ctx, err, http.StatusBadRequest, "decode request query failed")
|
2019-06-24 17:36:42 -08:00
|
|
|
}
|
2019-05-16 10:39:25 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
2019-06-25 22:31:54 -08:00
|
|
|
|
|
|
|
// ExtractWhereArgs extracts the sql args from where. This allows requests to accept sql queries for filters and
|
|
|
|
// then replaces the raw values with placeholders. The resulting query will then be executed with bind vars.
|
|
|
|
func ExtractWhereArgs(where string) (string, []interface{}, error) {
|
|
|
|
// Create a full select sql query.
|
|
|
|
query := "select `t` from test where " + where
|
|
|
|
|
|
|
|
// Parse the query.
|
|
|
|
stmt, err := sqlparser.Parse(query)
|
|
|
|
if err != nil {
|
|
|
|
return "", nil, errors.WithMessagef(err, "Failed to parse query - %s", where)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Normalize changes the query statement to use bind values, and updates the bind vars to those values. The
|
|
|
|
// supplied prefix is used to generate the bind var names.
|
|
|
|
bindVars := make(map[string]*querypb.BindVariable)
|
|
|
|
sqlparser.Normalize(stmt, bindVars, "redacted")
|
|
|
|
|
|
|
|
// Loop through all the bind vars and append to the response args list.
|
|
|
|
var vals []interface{}
|
|
|
|
for _, bv := range bindVars {
|
|
|
|
if bv.Values != nil {
|
|
|
|
var l []interface{}
|
|
|
|
for _, v := range bv.Values {
|
|
|
|
l = append(l, string(v.Value))
|
|
|
|
}
|
|
|
|
vals = append(vals, l)
|
|
|
|
} else {
|
|
|
|
vals = append(vals, string(bv.Value))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Update the original query to include the redacted values.
|
|
|
|
query = sqlparser.String(stmt)
|
|
|
|
|
|
|
|
// Parse out the updated where.
|
|
|
|
where = strings.Split(query, " where ")[1]
|
|
|
|
|
|
|
|
return where, vals, nil
|
|
|
|
}
|
2019-06-26 20:21:00 -08:00
|
|
|
|
|
|
|
// RequestIsJson reports whether the request should be treated as JSON.
//
// It returns true when any of the following holds:
//   - one of the ";"-separated parts of the Content-Type header is
//     "application/json" (case-insensitive, surrounding spaces ignored),
//   - the "ResponseFormat" query parameter equals "json" (case-insensitive),
//   - the URL path ends in ".json".
//
// A nil request, or a request without a URL and without a matching
// Content-Type, yields false.
func RequestIsJson(r *http.Request) bool {
	if r == nil {
		return false
	}

	if v := r.Header.Get("Content-type"); v != "" {
		for _, hv := range strings.Split(v, ";") {
			// Parts may be padded with spaces, e.g. "application/json ; charset=utf-8".
			if strings.ToLower(strings.TrimSpace(hv)) == "application/json" {
				return true
			}
		}
	}

	// The remaining checks inspect the URL, which can be nil on a request
	// value that was built by hand.
	if r.URL == nil {
		return false
	}

	if strings.ToLower(r.URL.Query().Get("ResponseFormat")) == "json" {
		return true
	}

	return strings.HasSuffix(r.URL.Path, ".json")
}
|
2019-07-12 11:41:41 -08:00
|
|
|
|
|
|
|
// RequestIsTLS reports whether the request carries TLS connection state,
// i.e. r.TLS is set. A nil request yields false.
func RequestIsTLS(r *http.Request) bool {
	return r != nil && r.TLS != nil
}
|
|
|
|
|
|
|
|
func RequestIsWebSocket(r *http.Request) bool {
|
|
|
|
upgrade := r.Header.Get(HeaderUpgrade)
|
|
|
|
return strings.ToLower(upgrade) == "websocket"
|
|
|
|
}
|
|
|
|
|
2019-08-12 21:28:16 -08:00
|
|
|
func RequestIsImage(r *http.Request) bool {
|
|
|
|
accept := r.Header.Get(HeaderAccept)
|
|
|
|
return strings.HasPrefix(accept, "image/")
|
|
|
|
}
|
|
|
|
|
2019-07-12 11:41:41 -08:00
|
|
|
func RequestScheme(r *http.Request) string {
|
|
|
|
// Can't use `r.Request.URL.Scheme`
|
|
|
|
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
|
|
|
|
if RequestIsTLS(r) {
|
|
|
|
return "https"
|
|
|
|
}
|
|
|
|
if scheme := r.Header.Get(HeaderXForwardedProto); scheme != "" {
|
|
|
|
return scheme
|
|
|
|
}
|
|
|
|
if scheme := r.Header.Get(HeaderXForwardedProtocol); scheme != "" {
|
|
|
|
return scheme
|
|
|
|
}
|
|
|
|
if ssl := r.Header.Get(HeaderXForwardedSsl); ssl == "on" {
|
|
|
|
return "https"
|
|
|
|
}
|
|
|
|
if scheme := r.Header.Get(HeaderXUrlScheme); scheme != "" {
|
|
|
|
return scheme
|
|
|
|
}
|
|
|
|
return "http"
|
|
|
|
}
|
|
|
|
|
|
|
|
func RequestRealIP(r *http.Request) string {
|
|
|
|
if ip := r.Header.Get(HeaderXForwardedFor); ip != "" {
|
|
|
|
return strings.Split(ip, ", ")[0]
|
|
|
|
}
|
|
|
|
if ip := r.Header.Get(HeaderXRealIP); ip != "" {
|
|
|
|
return ip
|
|
|
|
}
|
|
|
|
ra, _, _ := net.SplitHostPort(r.RemoteAddr)
|
|
|
|
return ra
|
|
|
|
}
|