1
0
mirror of https://github.com/go-micro/go-micro.git synced 2025-11-23 21:44:41 +02:00

fix: easy lint fixes to api/ (#2567)

This commit is contained in:
Rene Jochum
2022-10-01 10:50:11 +02:00
committed by GitHub
parent 010b1d9f11
commit 065f9714e9
23 changed files with 380 additions and 165 deletions

View File

@@ -164,7 +164,7 @@ func (client *Client) Call(service, endpoint string, request, response interface
// Stream enables the ability to stream via websockets. // Stream enables the ability to stream via websockets.
func (client *Client) Stream(service, endpoint string, request interface{}) (*Stream, error) { func (client *Client) Stream(service, endpoint string, request interface{}) (*Stream, error) {
b, err := marshalRequest(endpoint, request) bytes, err := marshalRequest(endpoint, request)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -196,7 +196,7 @@ func (client *Client) Stream(service, endpoint string, request interface{}) (*St
} }
// send the first request // send the first request
if err := conn.WriteMessage(websocket.TextMessage, b); err != nil { if err := conn.WriteMessage(websocket.TextMessage, bytes); err != nil {
return nil, err return nil, err
} }
@@ -224,7 +224,7 @@ func (s *Stream) Send(v interface{}) error {
return s.conn.WriteMessage(websocket.TextMessage, b) return s.conn.WriteMessage(websocket.TextMessage, b)
} }
func marshalRequest(endpoint string, v interface{}) ([]byte, error) { func marshalRequest(_ string, v interface{}) ([]byte, error) {
return json.Marshal(v) return json.Marshal(v)
} }

View File

@@ -18,6 +18,7 @@ type apiHandler struct {
} }
const ( const (
// Handler is the name of the Handler.
Handler = "api" Handler = "api"
) )
@@ -29,12 +30,15 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
r.Body = http.MaxBytesReader(w, r.Body, bsize) r.Body = http.MaxBytesReader(w, r.Body, bsize)
request, err := requestToProto(r) request, err := requestToProto(r)
if err != nil { if err != nil {
er := errors.InternalServerError("go.micro.api", err.Error()) er := errors.InternalServerError("go.micro.api", err.Error())
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(er.Error())) w.Write([]byte(er.Error()))
return return
} }
@@ -45,18 +49,23 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s, err := a.opts.Router.Route(r) s, err := a.opts.Router.Route(r)
if err != nil { if err != nil {
er := errors.InternalServerError("go.micro.api", err.Error()) er := errors.InternalServerError("go.micro.api", err.Error())
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(er.Error())) w.Write([]byte(er.Error()))
return return
} }
service = s service = s
} else { } else {
// we have no way of routing the request // we have no way of routing the request
er := errors.InternalServerError("go.micro.api", "no route found") er := errors.InternalServerError("go.micro.api", "no route found")
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(er.Error())) w.Write([]byte(er.Error()))
return return
} }
@@ -72,14 +81,17 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := c.Call(cx, req, rsp, client.WithSelectOption(so)); err != nil { if err := c.Call(cx, req, rsp, client.WithSelectOption(so)); err != nil {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
ce := errors.Parse(err.Error()) ce := errors.Parse(err.Error())
switch ce.Code { switch ce.Code {
case 0: case 0:
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
default: default:
w.WriteHeader(int(ce.Code)) w.WriteHeader(int(ce.Code))
} }
w.Write([]byte(ce.Error())) w.Write([]byte(ce.Error()))
return return
} else if rsp.StatusCode == 0 { } else if rsp.StatusCode == 0 {
rsp.StatusCode = http.StatusOK rsp.StatusCode = http.StatusOK
@@ -96,6 +108,7 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.WriteHeader(int(rsp.StatusCode)) w.WriteHeader(int(rsp.StatusCode))
w.Write([]byte(rsp.Body)) w.Write([]byte(rsp.Body))
} }
@@ -103,8 +116,10 @@ func (a *apiHandler) String() string {
return "api" return "api"
} }
// NewHandler returns an api.Handler.
func NewHandler(opts ...handler.Option) handler.Handler { func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.NewOptions(opts...) options := handler.NewOptions(opts...)
return &apiHandler{ return &apiHandler{
opts: options, opts: options,
} }

View File

@@ -20,7 +20,7 @@ var (
func requestToProto(r *http.Request) (*api.Request, error) { func requestToProto(r *http.Request) (*api.Request, error) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return nil, fmt.Errorf("Error parsing form: %v", err) return nil, fmt.Errorf("Error parsing form: %w", err)
} }
req := &api.Request{ req := &api.Request{
@@ -46,9 +46,11 @@ func requestToProto(r *http.Request) (*api.Request, error) {
default: default:
buf := bufferPool.Get() buf := bufferPool.Get()
defer bufferPool.Put(buf) defer bufferPool.Put(buf)
if _, err = buf.ReadFrom(r.Body); err != nil { if _, err = buf.ReadFrom(r.Body); err != nil {
return nil, err return nil, err
} }
req.Body = buf.String() req.Body = buf.String()
} }
} }
@@ -81,6 +83,7 @@ func requestToProto(r *http.Request) (*api.Request, error) {
} }
req.Get[key] = header req.Get[key] = header
} }
header.Values = vals header.Values = vals
} }
@@ -93,6 +96,7 @@ func requestToProto(r *http.Request) (*api.Request, error) {
} }
req.Post[key] = header req.Post[key] = header
} }
header.Values = vals header.Values = vals
} }
@@ -104,6 +108,7 @@ func requestToProto(r *http.Request) (*api.Request, error) {
} }
req.Header[key] = header req.Header[key] = header
} }
header.Values = vals header.Values = vals
} }

View File

@@ -26,6 +26,7 @@ type event struct {
} }
var ( var (
// Handler is the name of this handler.
Handler = "event" Handler = "event"
versionRe = regexp.MustCompilePOSIX("^v[0-9]+$") versionRe = regexp.MustCompilePOSIX("^v[0-9]+$")
) )
@@ -34,55 +35,56 @@ func eventName(parts []string) string {
return strings.Join(parts, ".") return strings.Join(parts, ".")
} }
func evRoute(ns, p string) (string, string) { func evRoute(namespace, myPath string) (string, string) {
p = path.Clean(p) myPath = path.Clean(myPath)
p = strings.TrimPrefix(p, "/") myPath = strings.TrimPrefix(myPath, "/")
if len(p) == 0 { if len(myPath) == 0 {
return ns, "event" return namespace, Handler
} }
parts := strings.Split(p, "/") parts := strings.Split(myPath, "/")
// no path // no path
if len(parts) == 0 { if len(parts) == 0 {
// topic: namespace // topic: namespace
// action: event // action: event
return strings.Trim(ns, "."), "event" return strings.Trim(namespace, "."), Handler
} }
// Treat /v[0-9]+ as versioning // Treat /v[0-9]+ as versioning
// /v1/foo/bar => topic: v1.foo action: bar // /v1/foo/bar => topic: v1.foo action: bar
if len(parts) >= 2 && versionRe.Match([]byte(parts[0])) { if len(parts) >= 2 && versionRe.Match([]byte(parts[0])) {
topic := ns + "." + strings.Join(parts[:2], ".") topic := namespace + "." + strings.Join(parts[:2], ".")
action := eventName(parts[1:]) action := eventName(parts[1:])
return topic, action return topic, action
} }
// /foo => topic: ns.foo action: foo // /foo => topic: ns.foo action: foo
// /foo/bar => topic: ns.foo action: bar // /foo/bar => topic: ns.foo action: bar
topic := ns + "." + strings.Join(parts[:1], ".") topic := namespace + "." + strings.Join(parts[:1], ".")
action := eventName(parts[1:]) action := eventName(parts[1:])
return topic, action return topic, action
} }
func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (e *event) ServeHTTP(rsp http.ResponseWriter, req *http.Request) {
bsize := handler.DefaultMaxRecvSize bsize := handler.DefaultMaxRecvSize
if e.opts.MaxRecvSize > 0 { if e.opts.MaxRecvSize > 0 {
bsize = e.opts.MaxRecvSize bsize = e.opts.MaxRecvSize
} }
r.Body = http.MaxBytesReader(w, r.Body, bsize) req.Body = http.MaxBytesReader(rsp, req.Body, bsize)
// request to topic:event // request to topic:event
// create event // create event
// publish to topic // publish to topic
topic, action := evRoute(e.opts.Namespace, r.URL.Path) topic, action := evRoute(e.opts.Namespace, req.URL.Path)
// create event // create event
ev := &proto.Event{ event := &proto.Event{
Name: action, Name: action,
// TODO: dedupe event // TODO: dedupe event
Id: fmt.Sprintf("%s-%s-%s", topic, action, uuid.New().String()), Id: fmt.Sprintf("%s-%s-%s", topic, action, uuid.New().String()),
@@ -91,49 +93,53 @@ func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// set headers // set headers
for key, vals := range r.Header { for key, vals := range req.Header {
header, ok := ev.Header[key] header, ok := event.Header[key]
if !ok { if !ok {
header = &proto.Pair{ header = &proto.Pair{
Key: key, Key: key,
} }
ev.Header[key] = header event.Header[key] = header
} }
header.Values = vals header.Values = vals
} }
// set body // set body
if r.Method == "GET" { if req.Method == http.MethodGet {
bytes, _ := json.Marshal(r.URL.Query()) bytes, _ := json.Marshal(req.URL.Query())
ev.Data = string(bytes) event.Data = string(bytes)
} else { } else {
// Read body // Read body
buf := bufferPool.Get() buf := bufferPool.Get()
defer bufferPool.Put(buf) defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil { if _, err := buf.ReadFrom(req.Body); err != nil {
http.Error(w, err.Error(), 500) http.Error(rsp, err.Error(), http.StatusInternalServerError)
return return
} }
ev.Data = buf.String() event.Data = buf.String()
} }
// get client // get client
c := e.opts.Client c := e.opts.Client
// create publication // create publication
p := c.NewMessage(topic, ev) p := c.NewMessage(topic, event)
// publish event // publish event
if err := c.Publish(ctx.FromRequest(r), p); err != nil { if err := c.Publish(ctx.FromRequest(req), p); err != nil {
http.Error(w, err.Error(), 500) http.Error(rsp, err.Error(), http.StatusInternalServerError)
return return
} }
} }
func (e *event) String() string { func (e *event) String() string {
return "event" return Handler
} }
// NewHandler returns a new event handler.
func NewHandler(opts ...handler.Option) handler.Handler { func NewHandler(opts ...handler.Option) handler.Handler {
return &event{ return &event{
opts: handler.NewOptions(opts...), opts: handler.NewOptions(opts...),

View File

@@ -14,6 +14,7 @@ import (
) )
const ( const (
// Handler is the name of the handler.
Handler = "http" Handler = "http"
) )
@@ -24,18 +25,18 @@ type httpHandler struct {
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
service, err := h.getService(r) service, err := h.getService(r)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
return return
} }
if len(service) == 0 { if len(service) == 0 {
w.WriteHeader(404) w.WriteHeader(http.StatusNotFound)
return return
} }
rp, err := url.Parse(service) rp, err := url.Parse(service)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@@ -52,6 +53,7 @@ func (h *httpHandler) getService(r *http.Request) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
service = s service = s
} else { } else {
// we have no way of routing the request // we have no way of routing the request

View File

@@ -7,9 +7,11 @@ import (
) )
var ( var (
DefaultMaxRecvSize int64 = 1024 * 1024 * 100 // 10Mb // DefaultMaxRecvSize is 10MiB.
DefaultMaxRecvSize int64 = 1024 * 1024 * 100
) )
// Options is the list of api Options.
type Options struct { type Options struct {
MaxRecvSize int64 MaxRecvSize int64
Namespace string Namespace string
@@ -18,6 +20,7 @@ type Options struct {
Logger logger.Logger Logger logger.Logger
} }
// Option is a api Option.
type Option func(o *Options) type Option func(o *Options)
// NewOptions fills in the blanks. // NewOptions fills in the blanks.
@@ -59,6 +62,7 @@ func WithRouter(r router.Router) Option {
} }
} }
// WithClient sets the client for the handler.
func WithClient(c client.Client) Option { func WithClient(c client.Client) Option {
return func(o *Options) { return func(o *Options) {
o.Client = c o.Client = c

View File

@@ -28,6 +28,7 @@ import (
) )
const ( const (
// Handler is the name of this handler.
Handler = "rpc" Handler = "rpc"
packageID = "go.micro.api" packageID = "go.micro.api"
) )
@@ -76,6 +77,7 @@ func strategy(services []*registry.Service) selector.Strategy {
func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logger := h.opts.Logger logger := h.opts.Logger
bsize := handler.DefaultMaxRecvSize bsize := handler.DefaultMaxRecvSize
if h.opts.MaxRecvSize > 0 { if h.opts.MaxRecvSize > 0 {
bsize = h.opts.MaxRecvSize bsize = h.opts.MaxRecvSize
} }
@@ -83,6 +85,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, bsize) r.Body = http.MaxBytesReader(w, r.Body, bsize)
defer r.Body.Close() defer r.Body.Close()
var service *router.Route var service *router.Route
if h.opts.Router != nil { if h.opts.Router != nil {
@@ -93,8 +96,10 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr != nil { if werr != nil {
logger.Log(log.ErrorLevel, werr) logger.Log(log.ErrorLevel, werr)
} }
return return
} }
service = s service = s
} else { } else {
// we have no way of routing the request // we have no way of routing the request
@@ -105,18 +110,18 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
ct := r.Header.Get("Content-Type") contentType := r.Header.Get("Content-Type")
// Strip charset from Content-Type (like `application/json; charset=UTF-8`) // Strip charset from Content-Type (like `application/json; charset=UTF-8`)
if idx := strings.IndexRune(ct, ';'); idx >= 0 { if idx := strings.IndexRune(contentType, ';'); idx >= 0 {
ct = ct[:idx] contentType = contentType[:idx]
} }
// micro client // micro client
c := h.opts.Client myClient := h.opts.Client
// create context // create context
cx := ctx.FromRequest(r) myContext := ctx.FromRequest(r)
// get context from http handler wrappers // get context from http handler wrappers
md, ok := metadata.FromContext(r.Context()) md, ok := metadata.FromContext(r.Context())
if !ok { if !ok {
@@ -132,23 +137,24 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// merge context with overwrite // merge context with overwrite
cx = metadata.MergeContext(cx, md, true) myContext = metadata.MergeContext(myContext, md, true)
// set merged context to request // set merged context to request
*r = *r.Clone(cx) *r = *r.Clone(myContext)
// if stream we currently only support json // if stream we currently only support json
if isStream(r, service) { if isStream(r, service) {
// drop older context as it can have timeouts and create new // drop older context as it can have timeouts and create new
// md, _ := metadata.FromContext(cx) // md, _ := metadata.FromContext(cx)
// serveWebsocket(context.TODO(), w, r, service, c) // serveWebsocket(context.TODO(), w, r, service, c)
if err := serveWebsocket(cx, w, r, service, c); err != nil { if err := serveWebsocket(myContext, w, r, service, myClient); err != nil {
logger.Log(log.ErrorLevel, err) logger.Log(log.ErrorLevel, err)
} }
return return
} }
// create strategy // create strategy
so := selector.WithStrategy(strategy(service.Versions)) mySelector := selector.WithStrategy(strategy(service.Versions))
// walk the standard call path // walk the standard call path
// get payload // get payload
@@ -157,6 +163,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr := writeError(w, r, err); werr != nil { if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr) logger.Log(log.ErrorLevel, werr)
} }
return return
} }
@@ -164,7 +171,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch { switch {
// proto codecs // proto codecs
case hasCodec(ct, protoCodecs): case hasCodec(contentType, protoCodecs):
request := &proto.Message{} request := &proto.Message{}
// if the extracted payload isn't empty lets use it // if the extracted payload isn't empty lets use it
if len(br) > 0 { if len(br) > 0 {
@@ -174,18 +181,19 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// create request/response // create request/response
response := &proto.Message{} response := &proto.Message{}
req := c.NewRequest( req := myClient.NewRequest(
service.Service, service.Service,
service.Endpoint.Name, service.Endpoint.Name,
request, request,
client.WithContentType(ct), client.WithContentType(contentType),
) )
// make the call // make the call
if err := c.Call(cx, req, response, client.WithSelectOption(so)); err != nil { if err := myClient.Call(myContext, req, response, client.WithSelectOption(mySelector)); err != nil {
if werr := writeError(w, r, err); werr != nil { if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr) logger.Log(log.ErrorLevel, werr)
} }
return return
} }
@@ -195,13 +203,14 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr := writeError(w, r, err); werr != nil { if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr) logger.Log(log.ErrorLevel, werr)
} }
return return
} }
default: default:
// if json codec is not present set to json // if json codec is not present set to json
if !hasCodec(ct, jsonCodecs) { if !hasCodec(contentType, jsonCodecs) {
ct = "application/json" contentType = "application/json"
} }
// default to trying json // default to trying json
@@ -214,17 +223,18 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// create request/response // create request/response
var response json.RawMessage var response json.RawMessage
req := c.NewRequest( req := myClient.NewRequest(
service.Service, service.Service,
service.Endpoint.Name, service.Endpoint.Name,
&request, &request,
client.WithContentType(ct), client.WithContentType(contentType),
) )
// make the call // make the call
if err := c.Call(cx, req, &response, client.WithSelectOption(so)); err != nil { if err := myClient.Call(myContext, req, &response, client.WithSelectOption(mySelector)); err != nil {
if werr := writeError(w, r, err); werr != nil { if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr) logger.Log(log.ErrorLevel, werr)
} }
return return
} }
@@ -234,6 +244,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr := writeError(w, r, err); werr != nil { if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr) logger.Log(log.ErrorLevel, werr)
} }
return return
} }
} }
@@ -244,8 +255,8 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (rh *rpcHandler) String() string { func (h *rpcHandler) String() string {
return "rpc" return Handler
} }
func hasCodec(ct string, codecs []string) bool { func hasCodec(ct string, codecs []string) bool {
@@ -254,6 +265,7 @@ func hasCodec(ct string, codecs []string) bool {
return true return true
} }
} }
return false return false
} }
@@ -266,37 +278,44 @@ func requestPayload(r *http.Request) ([]byte, error) {
// we have to decode json-rpc and proto-rpc because we suck // we have to decode json-rpc and proto-rpc because we suck
// well actually because there's no proxy codec right now // well actually because there's no proxy codec right now
ct := r.Header.Get("Content-Type") myCt := r.Header.Get("Content-Type")
switch { switch {
case strings.Contains(ct, "application/json-rpc"): case strings.Contains(myCt, "application/json-rpc"):
msg := codec.Message{ msg := codec.Message{
Type: codec.Request, Type: codec.Request,
Header: make(map[string]string), Header: make(map[string]string),
} }
c := jsonrpc.NewCodec(&buffer{r.Body}) c := jsonrpc.NewCodec(&buffer{r.Body})
if err = c.ReadHeader(&msg, codec.Request); err != nil { if err = c.ReadHeader(&msg, codec.Request); err != nil {
return nil, err return nil, err
} }
var raw json.RawMessage var raw json.RawMessage
if err = c.ReadBody(&raw); err != nil { if err = c.ReadBody(&raw); err != nil {
return nil, err return nil, err
} }
return ([]byte)(raw), nil return ([]byte)(raw), nil
case strings.Contains(ct, "application/proto-rpc"), strings.Contains(ct, "application/octet-stream"): case strings.Contains(myCt, "application/proto-rpc"), strings.Contains(myCt, "application/octet-stream"):
msg := codec.Message{ msg := codec.Message{
Type: codec.Request, Type: codec.Request,
Header: make(map[string]string), Header: make(map[string]string),
} }
c := protorpc.NewCodec(&buffer{r.Body}) c := protorpc.NewCodec(&buffer{r.Body})
if err = c.ReadHeader(&msg, codec.Request); err != nil { if err = c.ReadHeader(&msg, codec.Request); err != nil {
return nil, err return nil, err
} }
var raw proto.Message var raw proto.Message
if err = c.ReadBody(&raw); err != nil { if err = c.ReadBody(&raw); err != nil {
return nil, err return nil, err
} }
return raw.Marshal() return raw.Marshal()
case strings.Contains(ct, "application/www-x-form-urlencoded"): case strings.Contains(myCt, "application/www-x-form-urlencoded"):
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return nil, err return nil, err
} }
@@ -314,6 +333,7 @@ func requestPayload(r *http.Request) ([]byte, error) {
// otherwise as per usual // otherwise as per usual
ctx := r.Context() ctx := r.Context()
// dont user meadata.FromContext as it mangles names // dont user meadata.FromContext as it mangles names
md, ok := metadata.FromContext(ctx) md, ok := metadata.FromContext(ctx)
if !ok { if !ok {
@@ -330,9 +350,11 @@ func requestPayload(r *http.Request) ([]byte, error) {
// filter own keys // filter own keys
if strings.HasPrefix(k, "x-api-field-") { if strings.HasPrefix(k, "x-api-field-") {
matches[strings.TrimPrefix(k, "x-api-field-")] = v matches[strings.TrimPrefix(k, "x-api-field-")] = v
delete(md, k) delete(md, k)
} else if k == "x-api-body" { } else if k == "x-api-body" {
bodydst = v bodydst = v
delete(md, k) delete(md, k)
} }
} }
@@ -343,10 +365,12 @@ func requestPayload(r *http.Request) ([]byte, error) {
// get fields from url values // get fields from url values
if len(r.URL.RawQuery) > 0 { if len(r.URL.RawQuery) > 0 {
umd := make(map[string]interface{}) umd := make(map[string]interface{})
err = qson.Unmarshal(&umd, r.URL.RawQuery) err = qson.Unmarshal(&umd, r.URL.RawQuery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for k, v := range umd { for k, v := range umd {
matches[k] = v matches[k] = v
} }
@@ -361,24 +385,29 @@ func requestPayload(r *http.Request) ([]byte, error) {
req[k] = v req[k] = v
continue continue
} }
em := make(map[string]interface{}) em := make(map[string]interface{})
em[ps[len(ps)-1]] = v em[ps[len(ps)-1]] = v
for i := len(ps) - 2; i > 0; i-- { for i := len(ps) - 2; i > 0; i-- {
nm := make(map[string]interface{}) nm := make(map[string]interface{})
nm[ps[i]] = em nm[ps[i]] = em
em = nm em = nm
} }
if vm, ok := req[ps[0]]; ok { if vm, ok := req[ps[0]]; ok {
// nested map // nested map
nm := vm.(map[string]interface{}) nm := vm.(map[string]interface{})
for vk, vv := range em { for vk, vv := range em {
nm[vk] = vv nm[vk] = vv
} }
req[ps[0]] = nm req[ps[0]] = nm
} else { } else {
req[ps[0]] = em req[ps[0]] = em
} }
} }
pathbuf := []byte("{}") pathbuf := []byte("{}")
if len(req) > 0 { if len(req) > 0 {
pathbuf, err = json.Marshal(req) pathbuf, err = json.Marshal(req)
@@ -388,41 +417,48 @@ func requestPayload(r *http.Request) ([]byte, error) {
} }
urlbuf := []byte("{}") urlbuf := []byte("{}")
out, err := jsonpatch.MergeMergePatches(urlbuf, pathbuf) out, err := jsonpatch.MergeMergePatches(urlbuf, pathbuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch r.Method { switch r.Method {
case "GET": case http.MethodGet:
// empty response // empty response
if strings.Contains(ct, "application/json") && string(out) == "{}" { if strings.Contains(myCt, "application/json") && string(out) == "{}" {
return out, nil return out, nil
} else if string(out) == "{}" && !strings.Contains(ct, "application/json") { } else if string(out) == "{}" && !strings.Contains(myCt, "application/json") {
return []byte{}, nil return []byte{}, nil
} }
return out, nil return out, nil
case "PATCH", "POST", "PUT", "DELETE": case http.MethodPatch, http.MethodPost, http.MethodPut, http.MethodDelete:
bodybuf := []byte("{}") bodybuf := []byte("{}")
buf := bufferPool.Get() buf := bufferPool.Get()
defer bufferPool.Put(buf) defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil { if _, err := buf.ReadFrom(r.Body); err != nil {
return nil, err return nil, err
} }
if b := buf.Bytes(); len(b) > 0 { if b := buf.Bytes(); len(b) > 0 {
bodybuf = b bodybuf = b
} }
if bodydst == "" || bodydst == "*" { if bodydst == "" || bodydst == "*" {
if out, err = jsonpatch.MergeMergePatches(out, bodybuf); err == nil { if out, err = jsonpatch.MergeMergePatches(out, bodybuf); err == nil {
return out, nil return out, nil
} }
} }
var jsonbody map[string]interface{} var jsonbody map[string]interface{}
if json.Valid(bodybuf) { if json.Valid(bodybuf) {
if err = json.Unmarshal(bodybuf, &jsonbody); err != nil { if err = json.Unmarshal(bodybuf, &jsonbody); err != nil {
return nil, err return nil, err
} }
} }
dstmap := make(map[string]interface{}) dstmap := make(map[string]interface{})
ps := strings.Split(bodydst, ".") ps := strings.Split(bodydst, ".")
if len(ps) == 1 { if len(ps) == 1 {
@@ -463,33 +499,35 @@ func requestPayload(r *http.Request) ([]byte, error) {
return []byte{}, nil return []byte{}, nil
} }
func writeError(w http.ResponseWriter, r *http.Request, err error) error { func writeError(rsp http.ResponseWriter, req *http.Request, err error) error {
ce := errors.Parse(err.Error()) ce := errors.Parse(err.Error())
switch ce.Code { switch ce.Code {
case 0: case 0:
// assuming it's totally screwed // assuming it's totally screwed
ce.Code = 500 ce.Code = http.StatusInternalServerError
ce.Id = packageID ce.Id = packageID
ce.Status = http.StatusText(500) ce.Status = http.StatusText(http.StatusInternalServerError)
ce.Detail = "error during request: " + ce.Detail ce.Detail = "error during request: " + ce.Detail
w.WriteHeader(500)
rsp.WriteHeader(http.StatusInternalServerError)
default: default:
w.WriteHeader(int(ce.Code)) rsp.WriteHeader(int(ce.Code))
} }
// response content type // response content type
w.Header().Set("Content-Type", "application/json") rsp.Header().Set("Content-Type", "application/json")
// Set trailers // Set trailers
if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { if strings.Contains(req.Header.Get("Content-Type"), "application/grpc") {
w.Header().Set("Trailer", "grpc-status") rsp.Header().Set("Trailer", "grpc-status")
w.Header().Set("Trailer", "grpc-message") rsp.Header().Set("Trailer", "grpc-message")
w.Header().Set("grpc-status", "13") rsp.Header().Set("grpc-status", "13")
w.Header().Set("grpc-message", ce.Detail) rsp.Header().Set("grpc-message", ce.Detail)
} }
_, werr := w.Write([]byte(ce.Error())) _, werr := rsp.Write([]byte(ce.Error()))
return werr return werr
} }
@@ -512,11 +550,14 @@ func writeResponse(w http.ResponseWriter, r *http.Request, rsp []byte) error {
// write response // write response
_, err := w.Write(rsp) _, err := w.Write(rsp)
return err return err
} }
// NewHandler returns a new RPC handler.
func NewHandler(opts ...handler.Option) handler.Handler { func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.NewOptions(opts...) options := handler.NewOptions(opts...)
return &rpcHandler{ return &rpcHandler{
opts: options, opts: options,
} }

View File

@@ -20,32 +20,33 @@ import (
// serveWebsocket will stream rpc back over websockets assuming json. // serveWebsocket will stream rpc back over websockets assuming json.
func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *router.Route, c client.Client) (err error) { func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *router.Route, c client.Client) (err error) {
var op ws.OpCode var opCode ws.OpCode
ct := r.Header.Get("Content-Type") myCt := r.Header.Get("Content-Type")
// Strip charset from Content-Type (like `application/json; charset=UTF-8`) // Strip charset from Content-Type (like `application/json; charset=UTF-8`)
if idx := strings.IndexRune(ct, ';'); idx >= 0 { if idx := strings.IndexRune(myCt, ';'); idx >= 0 {
ct = ct[:idx] myCt = myCt[:idx]
} }
// check proto from request // check proto from request
switch ct { switch myCt {
case "application/json": case "application/json":
op = ws.OpText opCode = ws.OpText
default: default:
op = ws.OpBinary opCode = ws.OpBinary
} }
hdr := make(http.Header) hdr := make(http.Header)
if proto, ok := r.Header["Sec-Websocket-Protocol"]; ok { if proto, ok := r.Header["Sec-Websocket-Protocol"]; ok {
for _, p := range proto { for _, p := range proto {
switch p { if p == "binary" {
case "binary":
hdr["Sec-WebSocket-Protocol"] = []string{"binary"} hdr["Sec-WebSocket-Protocol"] = []string{"binary"}
op = ws.OpBinary opCode = ws.OpBinary
} }
} }
} }
payload, err := requestPayload(r) payload, err := requestPayload(r)
if err != nil { if err != nil {
return return
@@ -66,7 +67,7 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
Header: hdr, Header: hdr,
} }
conn, rw, _, err := upgrader.Upgrade(r, w) conn, uRw, _, err := upgrader.Upgrade(r, w)
if err != nil { if err != nil {
return return
} }
@@ -78,8 +79,9 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
}() }()
var request interface{} var request interface{}
if !bytes.Equal(payload, []byte(`{}`)) { if !bytes.Equal(payload, []byte(`{}`)) {
switch ct { switch myCt {
case "application/json", "": case "application/json", "":
m := json.RawMessage(payload) m := json.RawMessage(payload)
request = &m request = &m
@@ -89,14 +91,15 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
} }
// we always need to set content type for message // we always need to set content type for message
if ct == "" { if myCt == "" {
ct = "application/json" myCt = "application/json"
} }
req := c.NewRequest( req := c.NewRequest(
service.Service, service.Service,
service.Endpoint.Name, service.Endpoint.Name,
request, request,
client.WithContentType(ct), client.WithContentType(myCt),
client.StreamingRequest(), client.StreamingRequest(),
) )
@@ -114,7 +117,7 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
} }
go func() { go func() {
if wErr := writeLoop(rw, stream); wErr != nil && err == nil { if wErr := writeLoop(uRw, stream); wErr != nil && err == nil {
err = wErr err = wErr
} }
}() }()
@@ -136,14 +139,16 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
if strings.Contains(err.Error(), "context canceled") { if strings.Contains(err.Error(), "context canceled") {
return nil return nil
} }
return err return err
} }
// write the response // write the response
if err = wsutil.WriteServerMessage(rw, op, buf); err != nil { if err = wsutil.WriteServerMessage(uRw, opCode, buf); err != nil {
return err return err
} }
if err = rw.Flush(); err != nil {
if err = uRw.Flush(); err != nil {
return err return err
} }
} }
@@ -170,10 +175,12 @@ func writeLoop(rw io.ReadWriter, stream client.Stream) error {
case ws.StatusNormalClosure, ws.StatusNoStatusRcvd: case ws.StatusNormalClosure, ws.StatusNoStatusRcvd:
// this happens when user close ws connection, or we don't get any status // this happens when user close ws connection, or we don't get any status
return nil return nil
default:
return err
} }
} }
return err
} }
switch op { switch op {
default: default:
// not relevant // not relevant
@@ -210,6 +217,7 @@ func isStream(r *http.Request, srv *router.Route) bool {
} }
} }
} }
return false return false
} }
@@ -221,6 +229,7 @@ func isWebSocket(r *http.Request) bool {
return true return true
} }
} }
return false return false
} }

View File

@@ -17,6 +17,7 @@ import (
) )
const ( const (
// Handler is the name of the handler.
Handler = "web" Handler = "web"
) )
@@ -27,18 +28,18 @@ type webHandler struct {
func (wh *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (wh *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
service, err := wh.getService(r) service, err := wh.getService(r)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
return return
} }
if len(service) == 0 { if len(service) == 0 {
w.WriteHeader(404) w.WriteHeader(http.StatusNotFound)
return return
} }
rp, err := url.Parse(service) rp, err := url.Parse(service)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(http.StatusInternalServerError)
return return
} }
@@ -60,6 +61,7 @@ func (wh *webHandler) getService(r *http.Request) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
service = s service = s
} else { } else {
// we have no way of routing the request // we have no way of routing the request
@@ -79,12 +81,12 @@ func (wh *webHandler) getService(r *http.Request) (string, error) {
} }
// serveWebSocket used to serve a web socket proxied connection. // serveWebSocket used to serve a web socket proxied connection.
func (wh *webHandler) serveWebSocket(host string, w http.ResponseWriter, r *http.Request) { func (wh *webHandler) serveWebSocket(host string, rsp http.ResponseWriter, r *http.Request) {
req := new(http.Request) req := new(http.Request)
*req = *r *req = *r
if len(host) == 0 { if len(host) == 0 {
http.Error(w, "invalid host", 500) http.Error(rsp, "invalid host", http.StatusInternalServerError)
return return
} }
@@ -93,20 +95,21 @@ func (wh *webHandler) serveWebSocket(host string, w http.ResponseWriter, r *http
if ips, ok := req.Header["X-Forwarded-For"]; ok { if ips, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(ips, ", ") + ", " + clientIP clientIP = strings.Join(ips, ", ") + ", " + clientIP
} }
req.Header.Set("X-Forwarded-For", clientIP) req.Header.Set("X-Forwarded-For", clientIP)
} }
// connect to the backend host // connect to the backend host
conn, err := net.Dial("tcp", host) conn, err := net.Dial("tcp", host)
if err != nil { if err != nil {
http.Error(w, err.Error(), 500) http.Error(rsp, err.Error(), http.StatusInternalServerError)
return return
} }
// hijack the connection // hijack the connection
hj, ok := w.(http.Hijacker) hj, ok := rsp.(http.Hijacker)
if !ok { if !ok {
http.Error(w, "failed to connect", 500) http.Error(rsp, "failed to connect", http.StatusInternalServerError)
return return
} }
@@ -143,6 +146,7 @@ func isWebSocket(r *http.Request) bool {
return true return true
} }
} }
return false return false
} }
@@ -157,6 +161,7 @@ func (wh *webHandler) String() string {
return "web" return "web"
} }
// NewHandler returns a new web handler.
func NewHandler(opts ...handler.Option) handler.Handler { func NewHandler(opts ...handler.Option) handler.Handler {
return &webHandler{ return &webHandler{
opts: handler.NewOptions(opts...), opts: handler.NewOptions(opts...),

View File

@@ -9,8 +9,10 @@ import (
"go-micro.dev/v4/api/resolver" "go-micro.dev/v4/api/resolver"
) )
// Resolver is the gRPC Resolver.
type Resolver struct{} type Resolver struct{}
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) { func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
// /foo.Bar/Service // /foo.Bar/Service
if req.URL.Path == "/" { if req.URL.Path == "/" {
@@ -29,10 +31,12 @@ func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
}, nil }, nil
} }
// String returns the name of the resolver.
func (r *Resolver) String() string { func (r *Resolver) String() string {
return "grpc" return "grpc"
} }
// NewResolver creates a new gRPC resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver { func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{} return &Resolver{}
} }

View File

@@ -7,10 +7,12 @@ import (
"go-micro.dev/v4/api/resolver" "go-micro.dev/v4/api/resolver"
) )
// Resolver is a host resolver.
type Resolver struct { type Resolver struct {
opts resolver.Options opts resolver.Options
} }
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) { func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
return &resolver.Endpoint{ return &resolver.Endpoint{
Name: req.Host, Name: req.Host,
@@ -20,10 +22,12 @@ func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
}, nil }, nil
} }
// String returns the name of the resolver.
func (r *Resolver) String() string { func (r *Resolver) String() string {
return "host" return "host"
} }
// NewResolver creates a new host resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver { func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{opts: resolver.NewOptions(opts...)} return &Resolver{opts: resolver.NewOptions(opts...)}
} }

View File

@@ -4,6 +4,7 @@ import (
"net/http" "net/http"
) )
// NewOptions wires options together.
func NewOptions(opts ...Option) Options { func NewOptions(opts ...Option) Options {
var options Options var options Options
for _, o := range opts { for _, o := range opts {

View File

@@ -8,10 +8,12 @@ import (
"go-micro.dev/v4/api/resolver" "go-micro.dev/v4/api/resolver"
) )
// Resolver is a path resolver.
type Resolver struct { type Resolver struct {
opts resolver.Options opts resolver.Options
} }
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) { func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
if req.URL.Path == "/" { if req.URL.Path == "/" {
return nil, resolver.ErrNotFound return nil, resolver.ErrNotFound
@@ -32,6 +34,7 @@ func (r *Resolver) String() string {
return "path" return "path"
} }
// NewResolver returns a new path resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver { func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{opts: resolver.NewOptions(opts...)} return &Resolver{opts: resolver.NewOptions(opts...)}
} }

View File

@@ -29,11 +29,13 @@ type Endpoint struct {
Path string Path string
} }
// Options is a struct of available options.
type Options struct { type Options struct {
Handler string Handler string
Namespace func(*http.Request) string Namespace func(*http.Request) string
} }
// Option is a helper for a single option.
type Option func(o *Options) type Option func(o *Options)
// StaticNamespace returns the same namespace for each request. // StaticNamespace returns the same namespace for each request.

View File

@@ -10,10 +10,12 @@ import (
"go-micro.dev/v4/api/resolver" "go-micro.dev/v4/api/resolver"
) )
// NewResolver returns a new vpath resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver { func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{opts: resolver.NewOptions(opts...)} return &Resolver{opts: resolver.NewOptions(opts...)}
} }
// Resolver is a vpath resolver.
type Resolver struct { type Resolver struct {
opts resolver.Options opts resolver.Options
} }
@@ -22,6 +24,7 @@ var (
re = regexp.MustCompile("^v[0-9]+$") re = regexp.MustCompile("^v[0-9]+$")
) )
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) { func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
if req.URL.Path == "/" { if req.URL.Path == "/" {
return nil, errors.New("unknown name") return nil, errors.New("unknown name")

View File

@@ -7,6 +7,7 @@ import (
"go-micro.dev/v4/registry" "go-micro.dev/v4/registry"
) )
// Options is a struct of options available.
type Options struct { type Options struct {
Handler string Handler string
Registry registry.Registry Registry registry.Registry
@@ -14,8 +15,10 @@ type Options struct {
Logger logger.Logger Logger logger.Logger
} }
// Option is a helper for a single options.
type Option func(o *Options) type Option func(o *Options)
// NewOptions wires options together.
func NewOptions(opts ...Option) Options { func NewOptions(opts ...Option) Options {
options := Options{ options := Options{
Handler: "meta", Handler: "meta",

View File

@@ -51,14 +51,17 @@ func (r *registryRouter) isStopped() bool {
// refresh list of api services. // refresh list of api services.
func (r *registryRouter) refresh() { func (r *registryRouter) refresh() {
var attempts int var attempts int
logger := r.Options().Logger logger := r.Options().Logger
for { for {
services, err := r.opts.Registry.ListServices() services, err := r.opts.Registry.ListServices()
if err != nil { if err != nil {
attempts++ attempts++
logger.Logf(log.ErrorLevel, "unable to list services: %v", err) logger.Logf(log.ErrorLevel, "unable to list services: %v", err)
time.Sleep(time.Duration(attempts) * time.Second) time.Sleep(time.Duration(attempts) * time.Second)
continue continue
} }
@@ -71,6 +74,7 @@ func (r *registryRouter) refresh() {
logger.Logf(log.ErrorLevel, "unable to get service: %v", err) logger.Logf(log.ErrorLevel, "unable to get service: %v", err)
continue continue
} }
r.store(service) r.store(service)
} }
@@ -169,11 +173,13 @@ func (r *registryRouter) store(services []*registry.Service) {
if h == "" || h == "*" { if h == "" || h == "*" {
continue continue
} }
hostreg, err := regexp.CompilePOSIX(h) hostreg, err := regexp.CompilePOSIX(h)
if err != nil { if err != nil {
logger.Logf(log.TraceLevel, "endpoint have invalid host regexp: %v", err) logger.Logf(log.TraceLevel, "endpoint have invalid host regexp: %v", err)
continue continue
} }
cep.hostregs = append(cep.hostregs, hostreg) cep.hostregs = append(cep.hostregs, hostreg)
} }
@@ -197,11 +203,13 @@ func (r *registryRouter) store(services []*registry.Service) {
} }
tpl := rule.Compile() tpl := rule.Compile()
pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "", util.PatternLogger(logger)) pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "", util.PatternLogger(logger))
if err != nil { if err != nil {
logger.Logf(log.TraceLevel, "endpoint have invalid path pattern: %v", err) logger.Logf(log.TraceLevel, "endpoint have invalid path pattern: %v", err)
continue continue
} }
cep.pathregs = append(cep.pathregs, pathreg) cep.pathregs = append(cep.pathregs, pathreg)
} }
@@ -212,6 +220,7 @@ func (r *registryRouter) store(services []*registry.Service) {
// watch for endpoint changes. // watch for endpoint changes.
func (r *registryRouter) watch() { func (r *registryRouter) watch() {
var attempts int var attempts int
logger := r.Options().Logger logger := r.Options().Logger
for { for {
@@ -223,8 +232,10 @@ func (r *registryRouter) watch() {
w, err := r.opts.Registry.Watch() w, err := r.opts.Registry.Watch()
if err != nil { if err != nil {
attempts++ attempts++
logger.Logf(log.ErrorLevel, "error watching endpoints: %v", err) logger.Logf(log.ErrorLevel, "error watching endpoints: %v", err)
time.Sleep(time.Duration(attempts) * time.Second) time.Sleep(time.Duration(attempts) * time.Second)
continue continue
} }
@@ -248,8 +259,10 @@ func (r *registryRouter) watch() {
if err != nil { if err != nil {
logger.Logf(log.ErrorLevel, "error getting next endoint: %v", err) logger.Logf(log.ErrorLevel, "error getting next endoint: %v", err)
close(ch) close(ch)
break break
} }
r.process(res) r.process(res)
} }
} }
@@ -267,6 +280,7 @@ func (r *registryRouter) Stop() error {
close(r.exit) close(r.exit)
r.rc.Stop() r.rc.Stop()
} }
return nil return nil
} }
@@ -280,6 +294,7 @@ func (r *registryRouter) Deregister(ep *router.Route) error {
func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) { func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
logger := r.Options().Logger logger := r.Options().Logger
if r.isStopped() { if r.isStopped() {
return nil, errors.New("router closed") return nil, errors.New("router closed")
} }
@@ -291,16 +306,19 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
if len(req.URL.Path) > 0 && req.URL.Path != "/" { if len(req.URL.Path) > 0 && req.URL.Path != "/" {
idx = 1 idx = 1
} }
path := strings.Split(req.URL.Path[idx:], "/") path := strings.Split(req.URL.Path[idx:], "/")
// use the first match // use the first match
// TODO: weighted matching // TODO: weighted matching
for n, e := range r.eps { for n, endpoint := range r.eps {
cep, ok := r.ceps[n] cep, ok := r.ceps[n]
if !ok { if !ok {
continue continue
} }
ep := e.Endpoint
ep := endpoint.Endpoint
var mMatch, hMatch, pMatch bool var mMatch, hMatch, pMatch bool
// 1. try method // 1. try method
for _, m := range ep.Method { for _, m := range ep.Method {
@@ -309,6 +327,7 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
break break
} }
} }
if !mMatch { if !mMatch {
continue continue
} }
@@ -323,14 +342,13 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
if h == "" || h == "*" { if h == "" || h == "*" {
hMatch = true hMatch = true
break break
} else { } else if cep.hostregs[idx].MatchString(req.URL.Host) {
if cep.hostregs[idx].MatchString(req.URL.Host) { hMatch = true
hMatch = true break
break
}
} }
} }
} }
if !hMatch { if !hMatch {
continue continue
} }
@@ -344,17 +362,23 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
logger.Logf(log.DebugLevel, "api gpath not match %s != %v", path, pathreg) logger.Logf(log.DebugLevel, "api gpath not match %s != %v", path, pathreg)
continue continue
} }
logger.Logf(log.DebugLevel, "api gpath match %s = %v", path, pathreg) logger.Logf(log.DebugLevel, "api gpath match %s = %v", path, pathreg)
pMatch = true pMatch = true
ctx := req.Context() ctx := req.Context()
md, ok := metadata.FromContext(ctx) md, ok := metadata.FromContext(ctx)
if !ok { if !ok {
md = make(metadata.Metadata) md = make(metadata.Metadata)
} }
for k, v := range matches { for k, v := range matches {
md[fmt.Sprintf("x-api-field-%s", k)] = v md[fmt.Sprintf("x-api-field-%s", k)] = v
} }
*req = *req.Clone(metadata.NewContext(ctx, md)) *req = *req.Clone(metadata.NewContext(ctx, md))
break break
} }
@@ -365,8 +389,11 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
logger.Logf(log.DebugLevel, "api pcre path not match %s != %v", path, pathreg) logger.Logf(log.DebugLevel, "api pcre path not match %s != %v", path, pathreg)
continue continue
} }
logger.Logf(log.DebugLevel, "api pcre path match %s != %v", path, pathreg) logger.Logf(log.DebugLevel, "api pcre path match %s != %v", path, pathreg)
pMatch = true pMatch = true
break break
} }
} }
@@ -377,7 +404,7 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
// TODO: Percentage traffic // TODO: Percentage traffic
// we got here, so its a match // we got here, so its a match
return e, nil return endpoint, nil
} }
// no match // no match
@@ -400,13 +427,13 @@ func (r *registryRouter) Route(req *http.Request) (*router.Route, error) {
// TODO: don't ignore that shit // TODO: don't ignore that shit
// get the service name // get the service name
rp, err := r.opts.Resolver.Resolve(req) rsp, err := r.opts.Resolver.Resolve(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// service name // service name
name := rp.Name name := rsp.Name
// get service // get service
services, err := r.rc.GetService(name) services, err := r.rc.GetService(name)
@@ -429,7 +456,7 @@ func (r *registryRouter) Route(req *http.Request) (*router.Route, error) {
return &router.Route{ return &router.Route{
Service: name, Service: name,
Endpoint: &router.Endpoint{ Endpoint: &router.Endpoint{
Name: rp.Method, Name: rsp.Method,
Handler: handler, Handler: handler,
}, },
Versions: services, Versions: services,
@@ -462,8 +489,10 @@ func newRouter(opts ...router.Option) *registryRouter {
eps: make(map[string]*router.Route), eps: make(map[string]*router.Route),
ceps: make(map[string]*endpoint), ceps: make(map[string]*endpoint),
} }
go r.watch() go r.watch()
go r.refresh() go r.refresh()
return r return r
} }

View File

@@ -23,15 +23,15 @@ type endpoint struct {
pcreregs []*regexp.Regexp pcreregs []*regexp.Regexp
} }
// router is the default router. // Router is the default router.
type staticRouter struct { type Router struct {
exit chan bool exit chan bool
opts router.Options opts router.Options
sync.RWMutex sync.RWMutex
eps map[string]*endpoint eps map[string]*endpoint
} }
func (r *staticRouter) isStopd() bool { func (r *Router) isStopd() bool {
select { select {
case <-r.exit: case <-r.exit:
return true return true
@@ -87,18 +87,20 @@ func (r *staticRouter) watch() {
} }
*/ */
func (r *staticRouter) Register(route *router.Route) error { func (r *Router) Register(route *router.Route) error {
ep := route.Endpoint myEndpoint := route.Endpoint
if err := router.Validate(ep); err != nil { if err := router.Validate(myEndpoint); err != nil {
return err return err
} }
var pathregs []util.Pattern var (
var hostregs []*regexp.Regexp pathregs []util.Pattern
var pcreregs []*regexp.Regexp hostregs []*regexp.Regexp
pcreregs []*regexp.Regexp
)
for _, h := range ep.Host { for _, h := range myEndpoint.Host {
if h == "" || h == "*" { if h == "" || h == "*" {
continue continue
} }
@@ -106,10 +108,11 @@ func (r *staticRouter) Register(route *router.Route) error {
if err != nil { if err != nil {
return err return err
} }
hostregs = append(hostregs, hostreg) hostregs = append(hostregs, hostreg)
} }
for _, p := range ep.Path { for _, p := range myEndpoint.Path {
var pcreok bool var pcreok bool
// pcre only when we have start and end markers // pcre only when we have start and end markers
@@ -129,63 +132,70 @@ func (r *staticRouter) Register(route *router.Route) error {
} }
tpl := rule.Compile() tpl := rule.Compile()
pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "", util.PatternLogger(r.Options().Logger)) pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "", util.PatternLogger(r.Options().Logger))
if err != nil { if err != nil {
return err return err
} }
pathregs = append(pathregs, pathreg) pathregs = append(pathregs, pathreg)
} }
r.Lock() r.Lock()
r.eps[ep.Name] = &endpoint{ r.eps[myEndpoint.Name] = &endpoint{
apiep: ep, apiep: myEndpoint,
pcreregs: pcreregs, pcreregs: pcreregs,
pathregs: pathregs, pathregs: pathregs,
hostregs: hostregs, hostregs: hostregs,
} }
r.Unlock() r.Unlock()
return nil return nil
} }
func (r *staticRouter) Deregister(route *router.Route) error { func (r *Router) Deregister(route *router.Route) error {
ep := route.Endpoint ep := route.Endpoint
if err := router.Validate(ep); err != nil { if err := router.Validate(ep); err != nil {
return err return err
} }
r.Lock() r.Lock()
delete(r.eps, ep.Name) delete(r.eps, ep.Name)
r.Unlock() r.Unlock()
return nil return nil
} }
func (r *staticRouter) Options() router.Options { func (r *Router) Options() router.Options {
return r.opts return r.opts
} }
func (r *staticRouter) Stop() error { func (r *Router) Stop() error {
select { select {
case <-r.exit: case <-r.exit:
return nil return nil
default: default:
close(r.exit) close(r.exit)
} }
return nil return nil
} }
func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) { func (r *Router) Endpoint(req *http.Request) (*router.Route, error) {
ep, err := r.endpoint(req) myEndpoint, err := r.endpoint(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
epf := strings.Split(ep.apiep.Name, ".") epf := strings.Split(myEndpoint.apiep.Name, ".")
services, err := r.opts.Registry.GetService(epf[0]) services, err := r.opts.Registry.GetService(epf[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
// hack for stream endpoint // hack for stream endpoint
if ep.apiep.Stream { if myEndpoint.apiep.Stream {
svcs := rutil.Copy(services) svcs := rutil.Copy(services)
for _, svc := range svcs { for _, svc := range svcs {
if len(svc.Endpoints) == 0 { if len(svc.Endpoints) == 0 {
@@ -195,6 +205,7 @@ func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
e.Metadata["stream"] = "true" e.Metadata["stream"] = "true"
svc.Endpoints = append(svc.Endpoints, e) svc.Endpoints = append(svc.Endpoints, e)
} }
for _, e := range svc.Endpoints { for _, e := range svc.Endpoints {
e.Name = strings.Join(epf[1:], ".") e.Name = strings.Join(epf[1:], ".")
e.Metadata = make(map[string]string) e.Metadata = make(map[string]string)
@@ -210,10 +221,10 @@ func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
Endpoint: &router.Endpoint{ Endpoint: &router.Endpoint{
Name: strings.Join(epf[1:], "."), Name: strings.Join(epf[1:], "."),
Handler: "rpc", Handler: "rpc",
Host: ep.apiep.Host, Host: myEndpoint.apiep.Host,
Method: ep.apiep.Method, Method: myEndpoint.apiep.Method,
Path: ep.apiep.Path, Path: myEndpoint.apiep.Path,
Stream: ep.apiep.Stream, Stream: myEndpoint.apiep.Stream,
}, },
Versions: services, Versions: services,
} }
@@ -221,8 +232,9 @@ func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
return svc, nil return svc, nil
} }
func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) { func (r *Router) endpoint(req *http.Request) (*endpoint, error) {
logger := r.Options().Logger logger := r.Options().Logger
if r.isStopd() { if r.isStopd() {
return nil, errors.New("router closed") return nil, errors.New("router closed")
} }
@@ -234,75 +246,85 @@ func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) {
if len(req.URL.Path) > 0 && req.URL.Path != "/" { if len(req.URL.Path) > 0 && req.URL.Path != "/" {
idx = 1 idx = 1
} }
path := strings.Split(req.URL.Path[idx:], "/") path := strings.Split(req.URL.Path[idx:], "/")
// use the first match // use the first match
// TODO: weighted matching // TODO: weighted matching
for _, myEndpoint := range r.eps {
for _, ep := range r.eps {
var mMatch, hMatch, pMatch bool var mMatch, hMatch, pMatch bool
// 1. try method // 1. try method
for _, m := range ep.apiep.Method { for _, m := range myEndpoint.apiep.Method {
if m == req.Method { if m == req.Method {
mMatch = true mMatch = true
break break
} }
} }
if !mMatch { if !mMatch {
continue continue
} }
logger.Logf(log.DebugLevel, "api method match %s", req.Method) logger.Logf(log.DebugLevel, "api method match %s", req.Method)
// 2. try host // 2. try host
if len(ep.apiep.Host) == 0 { if len(myEndpoint.apiep.Host) == 0 {
hMatch = true hMatch = true
} else { } else {
for idx, h := range ep.apiep.Host { for idx, h := range myEndpoint.apiep.Host {
if h == "" || h == "*" { if h == "" || h == "*" {
hMatch = true hMatch = true
break break
} else { } else if myEndpoint.hostregs[idx].MatchString(req.URL.Host) {
if ep.hostregs[idx].MatchString(req.URL.Host) { hMatch = true
hMatch = true break
break
}
} }
} }
} }
if !hMatch { if !hMatch {
continue continue
} }
logger.Logf(log.DebugLevel, "api host match %s", req.URL.Host) logger.Logf(log.DebugLevel, "api host match %s", req.URL.Host)
// 3. try google.api path // 3. try google.api path
for _, pathreg := range ep.pathregs { for _, pathreg := range myEndpoint.pathregs {
matches, err := pathreg.Match(path, "") matches, err := pathreg.Match(path, "")
if err != nil { if err != nil {
logger.Logf(log.DebugLevel, "api gpath not match %s != %v", path, pathreg) logger.Logf(log.DebugLevel, "api gpath not match %s != %v", path, pathreg)
continue continue
} }
logger.Logf(log.DebugLevel, "api gpath match %s = %v", path, pathreg) logger.Logf(log.DebugLevel, "api gpath match %s = %v", path, pathreg)
pMatch = true pMatch = true
ctx := req.Context() ctx := req.Context()
md, ok := metadata.FromContext(ctx) md, ok := metadata.FromContext(ctx)
if !ok { if !ok {
md = make(metadata.Metadata) md = make(metadata.Metadata)
} }
for k, v := range matches { for k, v := range matches {
md[fmt.Sprintf("x-api-field-%s", k)] = v md[fmt.Sprintf("x-api-field-%s", k)] = v
} }
*req = *req.Clone(metadata.NewContext(ctx, md)) *req = *req.Clone(metadata.NewContext(ctx, md))
break break
} }
if !pMatch { if !pMatch {
// 4. try path via pcre path matching // 4. try path via pcre path matching
for _, pathreg := range ep.pcreregs { for _, pathreg := range myEndpoint.pcreregs {
if !pathreg.MatchString(req.URL.Path) { if !pathreg.MatchString(req.URL.Path) {
logger.Logf(log.DebugLevel, "api pcre path not match %s != %v", req.URL.Path, pathreg) logger.Logf(log.DebugLevel, "api pcre path not match %s != %v", req.URL.Path, pathreg)
continue continue
} }
pMatch = true pMatch = true
break break
} }
} }
@@ -313,14 +335,14 @@ func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) {
// TODO: Percentage traffic // TODO: Percentage traffic
// we got here, so its a match // we got here, so its a match
return ep, nil return myEndpoint, nil
} }
// no match // no match
return nil, fmt.Errorf("endpoint not found for %v", req.URL) return nil, fmt.Errorf("endpoint not found for %v", req.URL)
} }
func (r *staticRouter) Route(req *http.Request) (*router.Route, error) { func (r *Router) Route(req *http.Request) (*router.Route, error) {
if r.isStopd() { if r.isStopd() {
return nil, errors.New("router closed") return nil, errors.New("router closed")
} }
@@ -334,9 +356,10 @@ func (r *staticRouter) Route(req *http.Request) (*router.Route, error) {
return ep, nil return ep, nil
} }
func NewRouter(opts ...router.Option) *staticRouter { // NewRouter returns a new static router.
func NewRouter(opts ...router.Option) *Router {
options := router.NewOptions(opts...) options := router.NewOptions(opts...)
r := &staticRouter{ r := &Router{
exit: make(chan bool), exit: make(chan bool),
opts: options, opts: options,
eps: make(map[string]*endpoint), eps: make(map[string]*endpoint),

View File

@@ -1,6 +1,7 @@
package util package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/compile.go // download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/compile.go
const ( const (
opcodeVersion = 1 opcodeVersion = 1
@@ -66,6 +67,7 @@ func (v variable) compile() []op {
for _, s := range v.segments { for _, s := range v.segments {
ops = append(ops, s.compile()...) ops = append(ops, s.compile()...)
} }
ops = append(ops, op{ ops = append(ops, op{
code: OpConcatN, code: OpConcatN,
operand: len(v.segments), operand: len(v.segments),
@@ -88,7 +90,9 @@ func (t template) Compile() Template {
pool []string pool []string
fields []string fields []string
) )
consts := make(map[string]int) consts := make(map[string]int)
for _, op := range rawOps { for _, op := range rawOps {
ops = append(ops, int(op.code)) ops = append(ops, int(op.code))
if op.str == "" { if op.str == "" {
@@ -100,10 +104,12 @@ func (t template) Compile() Template {
} }
ops = append(ops, consts[op.str]) ops = append(ops, consts[op.str])
} }
if op.code == OpCapture { if op.code == OpCapture {
fields = append(fields, op.str) fields = append(fields, op.str)
} }
} }
return Template{ return Template{
Version: opcodeVersion, Version: opcodeVersion,
OpCodes: ops, OpCodes: ops,

View File

@@ -1,6 +1,7 @@
package util package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/parse.go // download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/parse.go
import ( import (
"fmt" "fmt"
@@ -24,9 +25,10 @@ func Parse(tmpl string) (Compiler, error) {
if !strings.HasPrefix(tmpl, "/") { if !strings.HasPrefix(tmpl, "/") {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"} return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"}
} }
tokens, verb := tokenize(tmpl[1:])
tokens, verb := tokenize(tmpl[1:])
p := parser{tokens: tokens} p := parser{tokens: tokens}
segs, err := p.topLevelSegments() segs, err := p.topLevelSegments()
if err != nil { if err != nil {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()} return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()}
@@ -52,8 +54,10 @@ func tokenize(path string) (tokens []string, verb string) {
var ( var (
st = init st = init
) )
for path != "" { for path != "" {
var idx int var idx int
switch st { switch st {
case init: case init:
idx = strings.IndexAny(path, "/{") idx = strings.IndexAny(path, "/{")
@@ -62,10 +66,12 @@ func tokenize(path string) (tokens []string, verb string) {
case nested: case nested:
idx = strings.IndexAny(path, "/}") idx = strings.IndexAny(path, "/}")
} }
if idx < 0 { if idx < 0 {
tokens = append(tokens, path) tokens = append(tokens, path)
break break
} }
switch r := path[idx]; r { switch r := path[idx]; r {
case '/', '.': case '/', '.':
case '{': case '{':
@@ -75,22 +81,27 @@ func tokenize(path string) (tokens []string, verb string) {
case '}': case '}':
st = init st = init
} }
if idx == 0 { if idx == 0 {
tokens = append(tokens, path[idx:idx+1]) tokens = append(tokens, path[idx:idx+1])
} else { } else {
tokens = append(tokens, path[:idx], path[idx:idx+1]) tokens = append(tokens, path[:idx], path[idx:idx+1])
} }
path = path[idx+1:] path = path[idx+1:]
} }
l := len(tokens) l := len(tokens)
t := tokens[l-1] t := tokens[l-1]
if idx := strings.LastIndex(t, ":"); idx == 0 { if idx := strings.LastIndex(t, ":"); idx == 0 {
tokens, verb = tokens[:l-1], t[1:] tokens, verb = tokens[:l-1], t[1:]
} else if idx > 0 { } else if idx > 0 {
tokens[l-1], verb = t[:idx], t[idx+1:] tokens[l-1], verb = t[:idx], t[idx+1:]
} }
tokens = append(tokens, eof) tokens = append(tokens, eof)
return tokens, verb return tokens, verb
} }
@@ -111,17 +122,22 @@ func (p *parser) topLevelSegments() ([]segment, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
logger.Logf(log.DebugLevel, "accept segments: %q; %q", p.accepted, p.tokens) logger.Logf(log.DebugLevel, "accept segments: %q; %q", p.accepted, p.tokens)
if _, err := p.accept(typeEOF); err != nil { if _, err := p.accept(typeEOF); err != nil {
return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, "")) return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, ""))
} }
logger.Logf(log.DebugLevel, "accept eof: %q; %q", p.accepted, p.tokens) logger.Logf(log.DebugLevel, "accept eof: %q; %q", p.accepted, p.tokens)
return segs, nil return segs, nil
} }
func (p *parser) segments() ([]segment, error) { func (p *parser) segments() ([]segment, error) {
logger := log.LoggerOrDefault(p.logger) logger := log.LoggerOrDefault(p.logger)
s, err := p.segment() s, err := p.segment()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -133,11 +149,14 @@ func (p *parser) segments() ([]segment, error) {
if _, err := p.accept("/"); err != nil { if _, err := p.accept("/"); err != nil {
return segs, nil return segs, nil
} }
s, err := p.segment() s, err := p.segment()
if err != nil { if err != nil {
return segs, err return segs, err
} }
segs = append(segs, s) segs = append(segs, s)
logger.Logf(log.DebugLevel, "accept segment: %q; %q", p.accepted, p.tokens) logger.Logf(log.DebugLevel, "accept segment: %q; %q", p.accepted, p.tokens)
} }
} }
@@ -146,17 +165,20 @@ func (p *parser) segment() (segment, error) {
if _, err := p.accept("*"); err == nil { if _, err := p.accept("*"); err == nil {
return wildcard{}, nil return wildcard{}, nil
} }
if _, err := p.accept("**"); err == nil { if _, err := p.accept("**"); err == nil {
return deepWildcard{}, nil return deepWildcard{}, nil
} }
if l, err := p.literal(); err == nil { if l, err := p.literal(); err == nil {
return l, nil return l, nil
} }
v, err := p.variable() v, err := p.variable()
if err != nil { if err != nil {
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %v", err) return nil, fmt.Errorf("segment neither wildcards, literal or variable: %w", err)
} }
return v, err return v, err
} }
@@ -165,6 +187,7 @@ func (p *parser) literal() (segment, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return literal(lit), nil return literal(lit), nil
} }
@@ -182,7 +205,7 @@ func (p *parser) variable() (segment, error) {
if _, err := p.accept("="); err == nil { if _, err := p.accept("="); err == nil {
segs, err = p.segments() segs, err = p.segments()
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid segment in variable %q: %v", path, err) return nil, fmt.Errorf("invalid segment in variable %q: %w", path, err)
} }
} else { } else {
segs = []segment{wildcard{}} segs = []segment{wildcard{}}
@@ -191,6 +214,7 @@ func (p *parser) variable() (segment, error) {
if _, err := p.accept("}"); err != nil { if _, err := p.accept("}"); err != nil {
return nil, fmt.Errorf("unterminated variable segment: %s", path) return nil, fmt.Errorf("unterminated variable segment: %s", path)
} }
return variable{ return variable{
path: path, path: path,
segments: segs, segments: segs,
@@ -202,7 +226,9 @@ func (p *parser) fieldPath() (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
components := []string{c} components := []string{c}
for { for {
if _, err = p.accept("."); err != nil { if _, err = p.accept("."); err != nil {
return strings.Join(components, "."), nil return strings.Join(components, "."), nil
@@ -238,6 +264,7 @@ const (
// If it doesn't match, the function does not consume any tokens and return an error. // If it doesn't match, the function does not consume any tokens and return an error.
func (p *parser) accept(term termType) (string, error) { func (p *parser) accept(term termType) (string, error) {
t := p.tokens[0] t := p.tokens[0]
switch term { switch term {
case "/", "*", "**", ".", "=", "{", "}": case "/", "*", "**", ".", "=", "{", "}":
if t != string(term) && t != "/" { if t != string(term) && t != "/" {
@@ -258,8 +285,10 @@ func (p *parser) accept(term termType) (string, error) {
default: default:
return "", fmt.Errorf("unknown termType %q", term) return "", fmt.Errorf("unknown termType %q", term)
} }
p.tokens = p.tokens[1:] p.tokens = p.tokens[1:]
p.accepted = append(p.accepted, t) p.accepted = append(p.accepted, t)
return t, nil return t, nil
} }
@@ -278,6 +307,7 @@ func expectPChars(t string) error {
pct1 pct1
pct2 pct2
) )
st := init st := init
for _, r := range t { for _, r := range t {
if st != init { if st != init {
@@ -316,9 +346,11 @@ func expectPChars(t string) error {
return fmt.Errorf("invalid character in path segment: %q(%U)", r, r) return fmt.Errorf("invalid character in path segment: %q(%U)", r, r)
} }
} }
if st != init { if st != init {
return fmt.Errorf("invalid percent-encoding in %q", t) return fmt.Errorf("invalid percent-encoding in %q", t)
} }
return nil return nil
} }
@@ -327,12 +359,14 @@ func expectIdent(ident string) error {
if ident == "" { if ident == "" {
return fmt.Errorf("empty identifier") return fmt.Errorf("empty identifier")
} }
for pos, r := range ident { for pos, r := range ident {
switch { switch {
case '0' <= r && r <= '9': case '0' <= r && r <= '9':
if pos == 0 { if pos == 0 {
return fmt.Errorf("identifier starting with digit: %s", ident) return fmt.Errorf("identifier starting with digit: %s", ident)
} }
continue continue
case 'A' <= r && r <= 'Z': case 'A' <= r && r <= 'Z':
continue continue
@@ -344,6 +378,7 @@ func expectIdent(ident string) error {
return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident) return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident)
} }
} }
return nil return nil
} }
@@ -356,5 +391,6 @@ func isHexDigit(r rune) bool {
case 'a' <= r && r <= 'f': case 'a' <= r && r <= 'f':
return true return true
} }
return false return false
} }

View File

@@ -49,7 +49,7 @@ type patternOptions struct {
// PatternOpt is an option for creating Patterns. // PatternOpt is an option for creating Patterns.
type PatternOpt func(*patternOptions) type PatternOpt func(*patternOptions)
// Logger sets the logger. // PatternLogger sets the logger.
func PatternLogger(l log.Logger) PatternOpt { func PatternLogger(l log.Logger) PatternOpt {
return func(po *patternOptions) { return func(po *patternOptions) {
po.logger = l po.logger = l
@@ -89,8 +89,10 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
pushMSeen bool pushMSeen bool
vars []string vars []string
) )
for i := 0; i < l; i += 2 { for i := 0; i < l; i += 2 {
op := rop{code: OpCode(ops[i]), operand: ops[i+1]} op := rop{code: OpCode(ops[i]), operand: ops[i+1]}
switch op.code { switch op.code {
case OpNop: case OpNop:
continue continue
@@ -104,6 +106,7 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "pushM appears twice") logger.Logf(log.DebugLevel, "pushM appears twice")
return Pattern{}, ErrInvalidPattern return Pattern{}, ErrInvalidPattern
} }
pushMSeen = true pushMSeen = true
stack++ stack++
case OpLitPush: case OpLitPush:
@@ -111,6 +114,7 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "negative literal index: %d", op.operand) logger.Logf(log.DebugLevel, "negative literal index: %d", op.operand)
return Pattern{}, ErrInvalidPattern return Pattern{}, ErrInvalidPattern
} }
if pushMSeen { if pushMSeen {
tailLen++ tailLen++
} }
@@ -120,6 +124,7 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "negative concat size: %d", op.operand) logger.Logf(log.DebugLevel, "negative concat size: %d", op.operand)
return Pattern{}, ErrInvalidPattern return Pattern{}, ErrInvalidPattern
} }
stack -= op.operand stack -= op.operand
if stack < 0 { if stack < 0 {
logger.Logf(log.DebugLevel, "stack underflow") logger.Logf(log.DebugLevel, "stack underflow")
@@ -131,10 +136,12 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "variable name index out of bound: %d", op.operand) logger.Logf(log.DebugLevel, "variable name index out of bound: %d", op.operand)
return Pattern{}, ErrInvalidPattern return Pattern{}, ErrInvalidPattern
} }
v := pool[op.operand] v := pool[op.operand]
op.operand = len(vars) op.operand = len(vars)
vars = append(vars, v) vars = append(vars, v)
stack-- stack--
if stack < 0 { if stack < 0 {
logger.Logf(log.DebugLevel, "stack underflow") logger.Logf(log.DebugLevel, "stack underflow")
return Pattern{}, ErrInvalidPattern return Pattern{}, ErrInvalidPattern
@@ -147,8 +154,10 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
if maxstack < stack { if maxstack < stack {
maxstack = stack maxstack = stack
} }
typedOps = append(typedOps, op) typedOps = append(typedOps, op)
} }
return Pattern{ return Pattern{
ops: typedOps, ops: typedOps,
pool: pool, pool: pool,

View File

@@ -1,6 +1,7 @@
package util package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types.go // download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types.go
import ( import (
"fmt" "fmt"
@@ -46,6 +47,7 @@ func (v variable) String() string {
for _, s := range v.segments { for _, s := range v.segments {
segs = append(segs, s.String()) segs = append(segs, s.String())
} }
return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/")) return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/"))
} }
@@ -54,9 +56,11 @@ func (t template) String() string {
for _, s := range t.segments { for _, s := range t.segments {
segs = append(segs, s.String()) segs = append(segs, s.String())
} }
str := strings.Join(segs, "/") str := strings.Join(segs, "/")
if t.verb != "" { if t.verb != "" {
str = fmt.Sprintf("%s:%s", str, t.verb) str = fmt.Sprintf("%s:%s", str, t.verb)
} }
return "/" + str return "/" + str
} }

View File

@@ -1,6 +1,7 @@
package util package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types_test.go // download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types_test.go
import ( import (
"fmt" "fmt"