package web

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"sort"
	"strings"

	"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/weberror"
	"github.com/gorilla/schema"
	"github.com/pkg/errors"
	"github.com/xwb1989/sqlparser"
	"github.com/xwb1989/sqlparser/dependency/querypb"
)

// Names of the HTTP headers referenced by the request helpers.
const (
	HeaderAccept              = "Accept"
	HeaderUpgrade             = "Upgrade"
	HeaderXForwardedFor       = "X-Forwarded-For"
	HeaderXForwardedProto     = "X-Forwarded-Proto"
	HeaderXForwardedProtocol  = "X-Forwarded-Protocol"
	HeaderXForwardedSsl       = "X-Forwarded-Ssl"
	HeaderXUrlScheme          = "X-Url-Scheme"
	HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
	HeaderXRealIP             = "X-Real-IP"
	HeaderXRequestID          = "X-Request-ID"
	HeaderXRequestedWith      = "X-Requested-With"
	HeaderServer              = "Server"
	HeaderOrigin              = "Origin"
)

// Decode populates val from the request.
//
// For POST, PUT, PATCH and DELETE requests the body is decoded as a JSON
// document; unknown JSON fields are rejected. For every other method the URL
// query parameters are decoded into val using gorilla/schema.
//
// Any decode failure is returned as a weberror carrying
// http.StatusBadRequest. Decode itself performs no validation of the decoded
// value; callers that need validation must run it separately.
func Decode(ctx context.Context, r *http.Request, val interface{}) error {
	switch r.Method {
	case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
		decoder := json.NewDecoder(r.Body)
		decoder.DisallowUnknownFields()
		if err := decoder.Decode(val); err != nil {
			return weberror.NewErrorMessage(ctx, err, http.StatusBadRequest, "decode request body failed")
		}
	default:
		decoder := schema.NewDecoder()
		if err := decoder.Decode(val, r.URL.Query()); err != nil {
			err = errors.Wrap(err, "decode request query failed")
			return weberror.NewErrorMessage(ctx, err, http.StatusBadRequest, "decode request query failed")
		}
	}

	return nil
}

// ExtractWhereArgs extracts the sql args from where. This allows requests to
// accept sql queries for filters and then replaces the raw values with
// placeholders. The resulting query will then be executed with bind vars.
//
// It returns the rewritten where clause, the extracted values, and an error
// when the clause cannot be parsed. Values are returned in a stable order,
// sorted by the generated bind variable name (redacted1, redacted2, ...).
// List values are returned as a nested []interface{}; all scalar values are
// returned as strings.
func ExtractWhereArgs(where string) (string, []interface{}, error) {
	// Wrap the clause in a full select statement so it can be parsed.
	query := "select `t` from test where " + where

	stmt, err := sqlparser.Parse(query)
	if err != nil {
		return "", nil, errors.WithMessagef(err, "Failed to parse query - %s", where)
	}

	// Normalize changes the query statement to use bind values, and updates
	// the bind vars to those values. The supplied prefix is used to generate
	// the bind var names.
	bindVars := make(map[string]*querypb.BindVariable)
	sqlparser.Normalize(stmt, bindVars, "redacted")

	// Map iteration order is random, so order the names explicitly to keep
	// the returned args deterministic. Shorter names sort first so that
	// redacted2 precedes redacted10.
	names := make([]string, 0, len(bindVars))
	for name := range bindVars {
		names = append(names, name)
	}
	sort.Slice(names, func(i, j int) bool {
		if len(names[i]) != len(names[j]) {
			return len(names[i]) < len(names[j])
		}
		return names[i] < names[j]
	})

	var vals []interface{}
	for _, name := range names {
		bv := bindVars[name]
		if bv.Values != nil {
			var l []interface{}
			for _, v := range bv.Values {
				l = append(l, string(v.Value))
			}
			vals = append(vals, l)
		} else {
			vals = append(vals, string(bv.Value))
		}
	}

	// Render the statement with the literal values replaced by placeholders.
	query = sqlparser.String(stmt)

	// Keep everything after the first " where ". Splitting only once
	// preserves filters that themselves contain " where " (e.g. sub-selects).
	parts := strings.SplitN(query, " where ", 2)
	if len(parts) != 2 {
		return "", nil, errors.Errorf("Failed to parse query - %s", where)
	}

	return parts[1], vals, nil
}

// RequestIsJson reports whether the request should be treated as JSON. It
// returns true when any of the following hold:
//   - the Content-Type header contains the segment "application/json"
//     (case-insensitive, surrounding whitespace ignored),
//   - the "ResponseFormat" query parameter equals "json" (case-insensitive),
//   - the URL path ends in ".json".
//
// A nil request yields false.
func RequestIsJson(r *http.Request) bool {
	if r == nil {
		return false
	}

	if v := r.Header.Get("Content-type"); v != "" {
		for _, hv := range strings.Split(v, ";") {
			if strings.EqualFold(strings.TrimSpace(hv), "application/json") {
				return true
			}
		}
	}

	if v := r.URL.Query().Get("ResponseFormat"); v != "" {
		if strings.EqualFold(v, "json") {
			return true
		}
	}

	return strings.HasSuffix(r.URL.Path, ".json")
}

// RequestIsTLS reports whether the request has TLS connection state attached.
func RequestIsTLS(r *http.Request) bool {
	return r.TLS != nil
}

// RequestIsWebSocket reports whether the Upgrade header equals "websocket"
// (case-insensitive).
func RequestIsWebSocket(r *http.Request) bool {
	return strings.EqualFold(r.Header.Get(HeaderUpgrade), "websocket")
}

// RequestIsImage reports whether the Accept header starts with "image/".
func RequestIsImage(r *http.Request) bool {
	accept := r.Header.Get(HeaderAccept)
	return strings.HasPrefix(accept, "image/")
}

// RequestScheme returns the scheme of the request. A request with TLS state
// is always "https"; otherwise the X-Forwarded-Proto, X-Forwarded-Protocol,
// X-Forwarded-Ssl ("on" means "https") and X-Url-Scheme headers are consulted
// in that order, and "http" is returned when none of them is set.
func RequestScheme(r *http.Request) string {
	// Can't use `r.Request.URL.Scheme`
	// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
	if RequestIsTLS(r) {
		return "https"
	}
	if scheme := r.Header.Get(HeaderXForwardedProto); scheme != "" {
		return scheme
	}
	if scheme := r.Header.Get(HeaderXForwardedProtocol); scheme != "" {
		return scheme
	}
	if ssl := r.Header.Get(HeaderXForwardedSsl); ssl == "on" {
		return "https"
	}
	if scheme := r.Header.Get(HeaderXUrlScheme); scheme != "" {
		return scheme
	}
	return "http"
}

// RequestRealIP returns the client address for the request. It prefers the
// first entry of the X-Forwarded-For header, then the X-Real-IP header, and
// finally the host portion of r.RemoteAddr.
//
// NOTE(review): both headers are supplied by the client or an intermediary,
// so the result should only be trusted when the service sits behind a proxy
// that sets them.
func RequestRealIP(r *http.Request) string {
	if xff := r.Header.Get(HeaderXForwardedFor); xff != "" {
		// Entries are comma separated; the space after the comma is optional.
		first := strings.SplitN(xff, ",", 2)[0]
		return strings.TrimSpace(first)
	}
	if ip := r.Header.Get(HeaderXRealIP); ip != "" {
		return ip
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr has no port (or is otherwise not host:port); return it
		// unchanged rather than an empty string.
		return r.RemoteAddr
	}
	return host
}