1
0
mirror of https://github.com/go-micro/go-micro.git synced 2024-12-24 10:07:04 +02:00

fix: easy lint fixes to api/ (#2567)

This commit is contained in:
Rene Jochum 2022-10-01 10:50:11 +02:00 committed by GitHub
parent 010b1d9f11
commit 065f9714e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 380 additions and 165 deletions

View File

@ -164,7 +164,7 @@ func (client *Client) Call(service, endpoint string, request, response interface
// Stream enables the ability to stream via websockets.
func (client *Client) Stream(service, endpoint string, request interface{}) (*Stream, error) {
b, err := marshalRequest(endpoint, request)
bytes, err := marshalRequest(endpoint, request)
if err != nil {
return nil, err
}
@ -196,7 +196,7 @@ func (client *Client) Stream(service, endpoint string, request interface{}) (*St
}
// send the first request
if err := conn.WriteMessage(websocket.TextMessage, b); err != nil {
if err := conn.WriteMessage(websocket.TextMessage, bytes); err != nil {
return nil, err
}
@ -224,7 +224,7 @@ func (s *Stream) Send(v interface{}) error {
return s.conn.WriteMessage(websocket.TextMessage, b)
}
func marshalRequest(endpoint string, v interface{}) ([]byte, error) {
func marshalRequest(_ string, v interface{}) ([]byte, error) {
return json.Marshal(v)
}

View File

@ -18,6 +18,7 @@ type apiHandler struct {
}
const (
// Handler is the name of the Handler.
Handler = "api"
)
@ -29,12 +30,15 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
request, err := requestToProto(r)
if err != nil {
er := errors.InternalServerError("go.micro.api", err.Error())
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(er.Error()))
return
}
@ -45,18 +49,23 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s, err := a.opts.Router.Route(r)
if err != nil {
er := errors.InternalServerError("go.micro.api", err.Error())
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(er.Error()))
return
}
service = s
} else {
// we have no way of routing the request
er := errors.InternalServerError("go.micro.api", "no route found")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(er.Error()))
return
}
@ -72,14 +81,17 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if err := c.Call(cx, req, rsp, client.WithSelectOption(so)); err != nil {
w.Header().Set("Content-Type", "application/json")
ce := errors.Parse(err.Error())
switch ce.Code {
case 0:
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(int(ce.Code))
}
w.Write([]byte(ce.Error()))
return
} else if rsp.StatusCode == 0 {
rsp.StatusCode = http.StatusOK
@ -96,6 +108,7 @@ func (a *apiHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(int(rsp.StatusCode))
w.Write([]byte(rsp.Body))
}
@ -103,8 +116,10 @@ func (a *apiHandler) String() string {
return "api"
}
// NewHandler returns an api.Handler.
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.NewOptions(opts...)
return &apiHandler{
opts: options,
}

View File

@ -20,7 +20,7 @@ var (
func requestToProto(r *http.Request) (*api.Request, error) {
if err := r.ParseForm(); err != nil {
return nil, fmt.Errorf("Error parsing form: %v", err)
return nil, fmt.Errorf("Error parsing form: %w", err)
}
req := &api.Request{
@ -46,9 +46,11 @@ func requestToProto(r *http.Request) (*api.Request, error) {
default:
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err = buf.ReadFrom(r.Body); err != nil {
return nil, err
}
req.Body = buf.String()
}
}
@ -81,6 +83,7 @@ func requestToProto(r *http.Request) (*api.Request, error) {
}
req.Get[key] = header
}
header.Values = vals
}
@ -93,6 +96,7 @@ func requestToProto(r *http.Request) (*api.Request, error) {
}
req.Post[key] = header
}
header.Values = vals
}
@ -104,6 +108,7 @@ func requestToProto(r *http.Request) (*api.Request, error) {
}
req.Header[key] = header
}
header.Values = vals
}

View File

@ -26,6 +26,7 @@ type event struct {
}
var (
// Handler is the name of this handler.
Handler = "event"
versionRe = regexp.MustCompilePOSIX("^v[0-9]+$")
)
@ -34,55 +35,56 @@ func eventName(parts []string) string {
return strings.Join(parts, ".")
}
func evRoute(ns, p string) (string, string) {
p = path.Clean(p)
p = strings.TrimPrefix(p, "/")
func evRoute(namespace, myPath string) (string, string) {
myPath = path.Clean(myPath)
myPath = strings.TrimPrefix(myPath, "/")
if len(p) == 0 {
return ns, "event"
if len(myPath) == 0 {
return namespace, Handler
}
parts := strings.Split(p, "/")
parts := strings.Split(myPath, "/")
// no path
if len(parts) == 0 {
// topic: namespace
// action: event
return strings.Trim(ns, "."), "event"
return strings.Trim(namespace, "."), Handler
}
// Treat /v[0-9]+ as versioning
// /v1/foo/bar => topic: v1.foo action: bar
if len(parts) >= 2 && versionRe.Match([]byte(parts[0])) {
topic := ns + "." + strings.Join(parts[:2], ".")
topic := namespace + "." + strings.Join(parts[:2], ".")
action := eventName(parts[1:])
return topic, action
}
// /foo => topic: ns.foo action: foo
// /foo/bar => topic: ns.foo action: bar
topic := ns + "." + strings.Join(parts[:1], ".")
topic := namespace + "." + strings.Join(parts[:1], ".")
action := eventName(parts[1:])
return topic, action
}
func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (e *event) ServeHTTP(rsp http.ResponseWriter, req *http.Request) {
bsize := handler.DefaultMaxRecvSize
if e.opts.MaxRecvSize > 0 {
bsize = e.opts.MaxRecvSize
}
r.Body = http.MaxBytesReader(w, r.Body, bsize)
req.Body = http.MaxBytesReader(rsp, req.Body, bsize)
// request to topic:event
// create event
// publish to topic
topic, action := evRoute(e.opts.Namespace, r.URL.Path)
topic, action := evRoute(e.opts.Namespace, req.URL.Path)
// create event
ev := &proto.Event{
event := &proto.Event{
Name: action,
// TODO: dedupe event
Id: fmt.Sprintf("%s-%s-%s", topic, action, uuid.New().String()),
@ -91,49 +93,53 @@ func (e *event) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// set headers
for key, vals := range r.Header {
header, ok := ev.Header[key]
for key, vals := range req.Header {
header, ok := event.Header[key]
if !ok {
header = &proto.Pair{
Key: key,
}
ev.Header[key] = header
event.Header[key] = header
}
header.Values = vals
}
// set body
if r.Method == "GET" {
bytes, _ := json.Marshal(r.URL.Query())
ev.Data = string(bytes)
if req.Method == http.MethodGet {
bytes, _ := json.Marshal(req.URL.Query())
event.Data = string(bytes)
} else {
// Read body
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
http.Error(w, err.Error(), 500)
if _, err := buf.ReadFrom(req.Body); err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
ev.Data = buf.String()
event.Data = buf.String()
}
// get client
c := e.opts.Client
// create publication
p := c.NewMessage(topic, ev)
p := c.NewMessage(topic, event)
// publish event
if err := c.Publish(ctx.FromRequest(r), p); err != nil {
http.Error(w, err.Error(), 500)
if err := c.Publish(ctx.FromRequest(req), p); err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
}
func (e *event) String() string {
return "event"
return Handler
}
// NewHandler returns a new event handler.
func NewHandler(opts ...handler.Option) handler.Handler {
return &event{
opts: handler.NewOptions(opts...),

View File

@ -14,6 +14,7 @@ import (
)
const (
// Handler is the name of the handler.
Handler = "http"
)
@ -24,18 +25,18 @@ type httpHandler struct {
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
service, err := h.getService(r)
if err != nil {
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
return
}
if len(service) == 0 {
w.WriteHeader(404)
w.WriteHeader(http.StatusNotFound)
return
}
rp, err := url.Parse(service)
if err != nil {
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -52,6 +53,7 @@ func (h *httpHandler) getService(r *http.Request) (string, error) {
if err != nil {
return "", err
}
service = s
} else {
// we have no way of routing the request

View File

@ -7,9 +7,11 @@ import (
)
var (
DefaultMaxRecvSize int64 = 1024 * 1024 * 100 // 10Mb
// DefaultMaxRecvSize is 10MiB.
DefaultMaxRecvSize int64 = 1024 * 1024 * 100
)
// Options is the list of api Options.
type Options struct {
MaxRecvSize int64
Namespace string
@ -18,6 +20,7 @@ type Options struct {
Logger logger.Logger
}
// Option is a api Option.
type Option func(o *Options)
// NewOptions fills in the blanks.
@ -59,6 +62,7 @@ func WithRouter(r router.Router) Option {
}
}
// WithClient sets the client for the handler.
func WithClient(c client.Client) Option {
return func(o *Options) {
o.Client = c

View File

@ -28,6 +28,7 @@ import (
)
const (
// Handler is the name of this handler.
Handler = "rpc"
packageID = "go.micro.api"
)
@ -76,6 +77,7 @@ func strategy(services []*registry.Service) selector.Strategy {
func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logger := h.opts.Logger
bsize := handler.DefaultMaxRecvSize
if h.opts.MaxRecvSize > 0 {
bsize = h.opts.MaxRecvSize
}
@ -83,6 +85,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, bsize)
defer r.Body.Close()
var service *router.Route
if h.opts.Router != nil {
@ -93,8 +96,10 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr != nil {
logger.Log(log.ErrorLevel, werr)
}
return
}
service = s
} else {
// we have no way of routing the request
@ -105,18 +110,18 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
ct := r.Header.Get("Content-Type")
contentType := r.Header.Get("Content-Type")
// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
if idx := strings.IndexRune(ct, ';'); idx >= 0 {
ct = ct[:idx]
if idx := strings.IndexRune(contentType, ';'); idx >= 0 {
contentType = contentType[:idx]
}
// micro client
c := h.opts.Client
myClient := h.opts.Client
// create context
cx := ctx.FromRequest(r)
myContext := ctx.FromRequest(r)
// get context from http handler wrappers
md, ok := metadata.FromContext(r.Context())
if !ok {
@ -132,23 +137,24 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// merge context with overwrite
cx = metadata.MergeContext(cx, md, true)
myContext = metadata.MergeContext(myContext, md, true)
// set merged context to request
*r = *r.Clone(cx)
*r = *r.Clone(myContext)
// if stream we currently only support json
if isStream(r, service) {
// drop older context as it can have timeouts and create new
// md, _ := metadata.FromContext(cx)
// serveWebsocket(context.TODO(), w, r, service, c)
if err := serveWebsocket(cx, w, r, service, c); err != nil {
if err := serveWebsocket(myContext, w, r, service, myClient); err != nil {
logger.Log(log.ErrorLevel, err)
}
return
}
// create strategy
so := selector.WithStrategy(strategy(service.Versions))
mySelector := selector.WithStrategy(strategy(service.Versions))
// walk the standard call path
// get payload
@ -157,6 +163,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr)
}
return
}
@ -164,7 +171,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
// proto codecs
case hasCodec(ct, protoCodecs):
case hasCodec(contentType, protoCodecs):
request := &proto.Message{}
// if the extracted payload isn't empty lets use it
if len(br) > 0 {
@ -174,18 +181,19 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// create request/response
response := &proto.Message{}
req := c.NewRequest(
req := myClient.NewRequest(
service.Service,
service.Endpoint.Name,
request,
client.WithContentType(ct),
client.WithContentType(contentType),
)
// make the call
if err := c.Call(cx, req, response, client.WithSelectOption(so)); err != nil {
if err := myClient.Call(myContext, req, response, client.WithSelectOption(mySelector)); err != nil {
if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr)
}
return
}
@ -195,13 +203,14 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr)
}
return
}
default:
// if json codec is not present set to json
if !hasCodec(ct, jsonCodecs) {
ct = "application/json"
if !hasCodec(contentType, jsonCodecs) {
contentType = "application/json"
}
// default to trying json
@ -214,17 +223,18 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// create request/response
var response json.RawMessage
req := c.NewRequest(
req := myClient.NewRequest(
service.Service,
service.Endpoint.Name,
&request,
client.WithContentType(ct),
client.WithContentType(contentType),
)
// make the call
if err := c.Call(cx, req, &response, client.WithSelectOption(so)); err != nil {
if err := myClient.Call(myContext, req, &response, client.WithSelectOption(mySelector)); err != nil {
if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr)
}
return
}
@ -234,6 +244,7 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if werr := writeError(w, r, err); werr != nil {
logger.Log(log.ErrorLevel, werr)
}
return
}
}
@ -244,8 +255,8 @@ func (h *rpcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (rh *rpcHandler) String() string {
return "rpc"
func (h *rpcHandler) String() string {
return Handler
}
func hasCodec(ct string, codecs []string) bool {
@ -254,6 +265,7 @@ func hasCodec(ct string, codecs []string) bool {
return true
}
}
return false
}
@ -266,37 +278,44 @@ func requestPayload(r *http.Request) ([]byte, error) {
// we have to decode json-rpc and proto-rpc because we suck
// well actually because there's no proxy codec right now
ct := r.Header.Get("Content-Type")
myCt := r.Header.Get("Content-Type")
switch {
case strings.Contains(ct, "application/json-rpc"):
case strings.Contains(myCt, "application/json-rpc"):
msg := codec.Message{
Type: codec.Request,
Header: make(map[string]string),
}
c := jsonrpc.NewCodec(&buffer{r.Body})
if err = c.ReadHeader(&msg, codec.Request); err != nil {
return nil, err
}
var raw json.RawMessage
if err = c.ReadBody(&raw); err != nil {
return nil, err
}
return ([]byte)(raw), nil
case strings.Contains(ct, "application/proto-rpc"), strings.Contains(ct, "application/octet-stream"):
case strings.Contains(myCt, "application/proto-rpc"), strings.Contains(myCt, "application/octet-stream"):
msg := codec.Message{
Type: codec.Request,
Header: make(map[string]string),
}
c := protorpc.NewCodec(&buffer{r.Body})
if err = c.ReadHeader(&msg, codec.Request); err != nil {
return nil, err
}
var raw proto.Message
if err = c.ReadBody(&raw); err != nil {
return nil, err
}
return raw.Marshal()
case strings.Contains(ct, "application/www-x-form-urlencoded"):
case strings.Contains(myCt, "application/www-x-form-urlencoded"):
if err := r.ParseForm(); err != nil {
return nil, err
}
@ -314,6 +333,7 @@ func requestPayload(r *http.Request) ([]byte, error) {
// otherwise as per usual
ctx := r.Context()
// dont user meadata.FromContext as it mangles names
md, ok := metadata.FromContext(ctx)
if !ok {
@ -330,9 +350,11 @@ func requestPayload(r *http.Request) ([]byte, error) {
// filter own keys
if strings.HasPrefix(k, "x-api-field-") {
matches[strings.TrimPrefix(k, "x-api-field-")] = v
delete(md, k)
} else if k == "x-api-body" {
bodydst = v
delete(md, k)
}
}
@ -343,10 +365,12 @@ func requestPayload(r *http.Request) ([]byte, error) {
// get fields from url values
if len(r.URL.RawQuery) > 0 {
umd := make(map[string]interface{})
err = qson.Unmarshal(&umd, r.URL.RawQuery)
if err != nil {
return nil, err
}
for k, v := range umd {
matches[k] = v
}
@ -361,24 +385,29 @@ func requestPayload(r *http.Request) ([]byte, error) {
req[k] = v
continue
}
em := make(map[string]interface{})
em[ps[len(ps)-1]] = v
for i := len(ps) - 2; i > 0; i-- {
nm := make(map[string]interface{})
nm[ps[i]] = em
em = nm
}
if vm, ok := req[ps[0]]; ok {
// nested map
nm := vm.(map[string]interface{})
for vk, vv := range em {
nm[vk] = vv
}
req[ps[0]] = nm
} else {
req[ps[0]] = em
}
}
pathbuf := []byte("{}")
if len(req) > 0 {
pathbuf, err = json.Marshal(req)
@ -388,41 +417,48 @@ func requestPayload(r *http.Request) ([]byte, error) {
}
urlbuf := []byte("{}")
out, err := jsonpatch.MergeMergePatches(urlbuf, pathbuf)
if err != nil {
return nil, err
}
switch r.Method {
case "GET":
case http.MethodGet:
// empty response
if strings.Contains(ct, "application/json") && string(out) == "{}" {
if strings.Contains(myCt, "application/json") && string(out) == "{}" {
return out, nil
} else if string(out) == "{}" && !strings.Contains(ct, "application/json") {
} else if string(out) == "{}" && !strings.Contains(myCt, "application/json") {
return []byte{}, nil
}
return out, nil
case "PATCH", "POST", "PUT", "DELETE":
case http.MethodPatch, http.MethodPost, http.MethodPut, http.MethodDelete:
bodybuf := []byte("{}")
buf := bufferPool.Get()
defer bufferPool.Put(buf)
if _, err := buf.ReadFrom(r.Body); err != nil {
return nil, err
}
if b := buf.Bytes(); len(b) > 0 {
bodybuf = b
}
if bodydst == "" || bodydst == "*" {
if out, err = jsonpatch.MergeMergePatches(out, bodybuf); err == nil {
return out, nil
}
}
var jsonbody map[string]interface{}
if json.Valid(bodybuf) {
if err = json.Unmarshal(bodybuf, &jsonbody); err != nil {
return nil, err
}
}
dstmap := make(map[string]interface{})
ps := strings.Split(bodydst, ".")
if len(ps) == 1 {
@ -463,33 +499,35 @@ func requestPayload(r *http.Request) ([]byte, error) {
return []byte{}, nil
}
func writeError(w http.ResponseWriter, r *http.Request, err error) error {
func writeError(rsp http.ResponseWriter, req *http.Request, err error) error {
ce := errors.Parse(err.Error())
switch ce.Code {
case 0:
// assuming it's totally screwed
ce.Code = 500
ce.Code = http.StatusInternalServerError
ce.Id = packageID
ce.Status = http.StatusText(500)
ce.Status = http.StatusText(http.StatusInternalServerError)
ce.Detail = "error during request: " + ce.Detail
w.WriteHeader(500)
rsp.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(int(ce.Code))
rsp.WriteHeader(int(ce.Code))
}
// response content type
w.Header().Set("Content-Type", "application/json")
rsp.Header().Set("Content-Type", "application/json")
// Set trailers
if strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
w.Header().Set("Trailer", "grpc-status")
w.Header().Set("Trailer", "grpc-message")
w.Header().Set("grpc-status", "13")
w.Header().Set("grpc-message", ce.Detail)
if strings.Contains(req.Header.Get("Content-Type"), "application/grpc") {
rsp.Header().Set("Trailer", "grpc-status")
rsp.Header().Set("Trailer", "grpc-message")
rsp.Header().Set("grpc-status", "13")
rsp.Header().Set("grpc-message", ce.Detail)
}
_, werr := w.Write([]byte(ce.Error()))
_, werr := rsp.Write([]byte(ce.Error()))
return werr
}
@ -512,11 +550,14 @@ func writeResponse(w http.ResponseWriter, r *http.Request, rsp []byte) error {
// write response
_, err := w.Write(rsp)
return err
}
// NewHandler returns a new RPC handler.
func NewHandler(opts ...handler.Option) handler.Handler {
options := handler.NewOptions(opts...)
return &rpcHandler{
opts: options,
}

View File

@ -20,32 +20,33 @@ import (
// serveWebsocket will stream rpc back over websockets assuming json.
func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *router.Route, c client.Client) (err error) {
var op ws.OpCode
var opCode ws.OpCode
ct := r.Header.Get("Content-Type")
myCt := r.Header.Get("Content-Type")
// Strip charset from Content-Type (like `application/json; charset=UTF-8`)
if idx := strings.IndexRune(ct, ';'); idx >= 0 {
ct = ct[:idx]
if idx := strings.IndexRune(myCt, ';'); idx >= 0 {
myCt = myCt[:idx]
}
// check proto from request
switch ct {
switch myCt {
case "application/json":
op = ws.OpText
opCode = ws.OpText
default:
op = ws.OpBinary
opCode = ws.OpBinary
}
hdr := make(http.Header)
if proto, ok := r.Header["Sec-Websocket-Protocol"]; ok {
for _, p := range proto {
switch p {
case "binary":
if p == "binary" {
hdr["Sec-WebSocket-Protocol"] = []string{"binary"}
op = ws.OpBinary
opCode = ws.OpBinary
}
}
}
payload, err := requestPayload(r)
if err != nil {
return
@ -66,7 +67,7 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
Header: hdr,
}
conn, rw, _, err := upgrader.Upgrade(r, w)
conn, uRw, _, err := upgrader.Upgrade(r, w)
if err != nil {
return
}
@ -78,8 +79,9 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
}()
var request interface{}
if !bytes.Equal(payload, []byte(`{}`)) {
switch ct {
switch myCt {
case "application/json", "":
m := json.RawMessage(payload)
request = &m
@ -89,14 +91,15 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
}
// we always need to set content type for message
if ct == "" {
ct = "application/json"
if myCt == "" {
myCt = "application/json"
}
req := c.NewRequest(
service.Service,
service.Endpoint.Name,
request,
client.WithContentType(ct),
client.WithContentType(myCt),
client.StreamingRequest(),
)
@ -114,7 +117,7 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
}
go func() {
if wErr := writeLoop(rw, stream); wErr != nil && err == nil {
if wErr := writeLoop(uRw, stream); wErr != nil && err == nil {
err = wErr
}
}()
@ -136,14 +139,16 @@ func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request,
if strings.Contains(err.Error(), "context canceled") {
return nil
}
return err
}
// write the response
if err = wsutil.WriteServerMessage(rw, op, buf); err != nil {
if err = wsutil.WriteServerMessage(uRw, opCode, buf); err != nil {
return err
}
if err = rw.Flush(); err != nil {
if err = uRw.Flush(); err != nil {
return err
}
}
@ -170,10 +175,12 @@ func writeLoop(rw io.ReadWriter, stream client.Stream) error {
case ws.StatusNormalClosure, ws.StatusNoStatusRcvd:
// this happens when user close ws connection, or we don't get any status
return nil
default:
return err
}
}
return err
}
switch op {
default:
// not relevant
@ -210,6 +217,7 @@ func isStream(r *http.Request, srv *router.Route) bool {
}
}
}
return false
}
@ -221,6 +229,7 @@ func isWebSocket(r *http.Request) bool {
return true
}
}
return false
}

View File

@ -17,6 +17,7 @@ import (
)
const (
// Handler is the name of the handler.
Handler = "web"
)
@ -27,18 +28,18 @@ type webHandler struct {
func (wh *webHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
service, err := wh.getService(r)
if err != nil {
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
return
}
if len(service) == 0 {
w.WriteHeader(404)
w.WriteHeader(http.StatusNotFound)
return
}
rp, err := url.Parse(service)
if err != nil {
w.WriteHeader(500)
w.WriteHeader(http.StatusInternalServerError)
return
}
@ -60,6 +61,7 @@ func (wh *webHandler) getService(r *http.Request) (string, error) {
if err != nil {
return "", err
}
service = s
} else {
// we have no way of routing the request
@ -79,12 +81,12 @@ func (wh *webHandler) getService(r *http.Request) (string, error) {
}
// serveWebSocket used to serve a web socket proxied connection.
func (wh *webHandler) serveWebSocket(host string, w http.ResponseWriter, r *http.Request) {
func (wh *webHandler) serveWebSocket(host string, rsp http.ResponseWriter, r *http.Request) {
req := new(http.Request)
*req = *r
if len(host) == 0 {
http.Error(w, "invalid host", 500)
http.Error(rsp, "invalid host", http.StatusInternalServerError)
return
}
@ -93,20 +95,21 @@ func (wh *webHandler) serveWebSocket(host string, w http.ResponseWriter, r *http
if ips, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(ips, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
}
// connect to the backend host
conn, err := net.Dial("tcp", host)
if err != nil {
http.Error(w, err.Error(), 500)
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
// hijack the connection
hj, ok := w.(http.Hijacker)
hj, ok := rsp.(http.Hijacker)
if !ok {
http.Error(w, "failed to connect", 500)
http.Error(rsp, "failed to connect", http.StatusInternalServerError)
return
}
@ -143,6 +146,7 @@ func isWebSocket(r *http.Request) bool {
return true
}
}
return false
}
@ -157,6 +161,7 @@ func (wh *webHandler) String() string {
return "web"
}
// NewHandler returns a new web handler.
func NewHandler(opts ...handler.Option) handler.Handler {
return &webHandler{
opts: handler.NewOptions(opts...),

View File

@ -9,8 +9,10 @@ import (
"go-micro.dev/v4/api/resolver"
)
// Resolver is the gRPC Resolver.
type Resolver struct{}
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
// /foo.Bar/Service
if req.URL.Path == "/" {
@ -29,10 +31,12 @@ func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
}, nil
}
// String returns the name of the resolver.
func (r *Resolver) String() string {
return "grpc"
}
// NewResolver creates a new gRPC resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{}
}

View File

@ -7,10 +7,12 @@ import (
"go-micro.dev/v4/api/resolver"
)
// Resolver is a host resolver.
type Resolver struct {
opts resolver.Options
}
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
return &resolver.Endpoint{
Name: req.Host,
@ -20,10 +22,12 @@ func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
}, nil
}
// String returns the name of the resolver.
func (r *Resolver) String() string {
return "host"
}
// NewResolver creates a new host resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{opts: resolver.NewOptions(opts...)}
}

View File

@ -4,6 +4,7 @@ import (
"net/http"
)
// NewOptions wires options together.
func NewOptions(opts ...Option) Options {
var options Options
for _, o := range opts {

View File

@ -8,10 +8,12 @@ import (
"go-micro.dev/v4/api/resolver"
)
// Resolver is a path resolver.
type Resolver struct {
opts resolver.Options
}
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
if req.URL.Path == "/" {
return nil, resolver.ErrNotFound
@ -32,6 +34,7 @@ func (r *Resolver) String() string {
return "path"
}
// NewResolver returns a new path resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{opts: resolver.NewOptions(opts...)}
}

View File

@ -29,11 +29,13 @@ type Endpoint struct {
Path string
}
// Options is a struct of available options.
type Options struct {
Handler string
Namespace func(*http.Request) string
}
// Option is a helper for a single option.
type Option func(o *Options)
// StaticNamespace returns the same namespace for each request.

View File

@ -10,10 +10,12 @@ import (
"go-micro.dev/v4/api/resolver"
)
// NewResolver returns a new vpath resolver.
func NewResolver(opts ...resolver.Option) resolver.Resolver {
return &Resolver{opts: resolver.NewOptions(opts...)}
}
// Resolver is a vpath resolver.
type Resolver struct {
opts resolver.Options
}
@ -22,6 +24,7 @@ var (
re = regexp.MustCompile("^v[0-9]+$")
)
// Resolve resolves a http.Request to an grpc Endpoint.
func (r *Resolver) Resolve(req *http.Request) (*resolver.Endpoint, error) {
if req.URL.Path == "/" {
return nil, errors.New("unknown name")

View File

@ -7,6 +7,7 @@ import (
"go-micro.dev/v4/registry"
)
// Options is a struct of options available.
type Options struct {
Handler string
Registry registry.Registry
@ -14,8 +15,10 @@ type Options struct {
Logger logger.Logger
}
// Option is a helper for a single options.
type Option func(o *Options)
// NewOptions wires options together.
func NewOptions(opts ...Option) Options {
options := Options{
Handler: "meta",

View File

@ -51,14 +51,17 @@ func (r *registryRouter) isStopped() bool {
// refresh list of api services.
func (r *registryRouter) refresh() {
var attempts int
logger := r.Options().Logger
for {
services, err := r.opts.Registry.ListServices()
if err != nil {
attempts++
logger.Logf(log.ErrorLevel, "unable to list services: %v", err)
time.Sleep(time.Duration(attempts) * time.Second)
continue
}
@ -71,6 +74,7 @@ func (r *registryRouter) refresh() {
logger.Logf(log.ErrorLevel, "unable to get service: %v", err)
continue
}
r.store(service)
}
@ -169,11 +173,13 @@ func (r *registryRouter) store(services []*registry.Service) {
if h == "" || h == "*" {
continue
}
hostreg, err := regexp.CompilePOSIX(h)
if err != nil {
logger.Logf(log.TraceLevel, "endpoint have invalid host regexp: %v", err)
continue
}
cep.hostregs = append(cep.hostregs, hostreg)
}
@ -197,11 +203,13 @@ func (r *registryRouter) store(services []*registry.Service) {
}
tpl := rule.Compile()
pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "", util.PatternLogger(logger))
if err != nil {
logger.Logf(log.TraceLevel, "endpoint have invalid path pattern: %v", err)
continue
}
cep.pathregs = append(cep.pathregs, pathreg)
}
@ -212,6 +220,7 @@ func (r *registryRouter) store(services []*registry.Service) {
// watch for endpoint changes.
func (r *registryRouter) watch() {
var attempts int
logger := r.Options().Logger
for {
@ -223,8 +232,10 @@ func (r *registryRouter) watch() {
w, err := r.opts.Registry.Watch()
if err != nil {
attempts++
logger.Logf(log.ErrorLevel, "error watching endpoints: %v", err)
time.Sleep(time.Duration(attempts) * time.Second)
continue
}
@ -248,8 +259,10 @@ func (r *registryRouter) watch() {
if err != nil {
logger.Logf(log.ErrorLevel, "error getting next endoint: %v", err)
close(ch)
break
}
r.process(res)
}
}
@ -267,6 +280,7 @@ func (r *registryRouter) Stop() error {
close(r.exit)
r.rc.Stop()
}
return nil
}
@ -280,6 +294,7 @@ func (r *registryRouter) Deregister(ep *router.Route) error {
func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
logger := r.Options().Logger
if r.isStopped() {
return nil, errors.New("router closed")
}
@ -291,16 +306,19 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
if len(req.URL.Path) > 0 && req.URL.Path != "/" {
idx = 1
}
path := strings.Split(req.URL.Path[idx:], "/")
// use the first match
// TODO: weighted matching
for n, e := range r.eps {
for n, endpoint := range r.eps {
cep, ok := r.ceps[n]
if !ok {
continue
}
ep := e.Endpoint
ep := endpoint.Endpoint
var mMatch, hMatch, pMatch bool
// 1. try method
for _, m := range ep.Method {
@ -309,6 +327,7 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
break
}
}
if !mMatch {
continue
}
@ -323,14 +342,13 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
if h == "" || h == "*" {
hMatch = true
break
} else {
if cep.hostregs[idx].MatchString(req.URL.Host) {
hMatch = true
break
}
} else if cep.hostregs[idx].MatchString(req.URL.Host) {
hMatch = true
break
}
}
}
if !hMatch {
continue
}
@ -344,17 +362,23 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
logger.Logf(log.DebugLevel, "api gpath not match %s != %v", path, pathreg)
continue
}
logger.Logf(log.DebugLevel, "api gpath match %s = %v", path, pathreg)
pMatch = true
ctx := req.Context()
md, ok := metadata.FromContext(ctx)
if !ok {
md = make(metadata.Metadata)
}
for k, v := range matches {
md[fmt.Sprintf("x-api-field-%s", k)] = v
}
*req = *req.Clone(metadata.NewContext(ctx, md))
break
}
@ -365,8 +389,11 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
logger.Logf(log.DebugLevel, "api pcre path not match %s != %v", path, pathreg)
continue
}
logger.Logf(log.DebugLevel, "api pcre path match %s != %v", path, pathreg)
pMatch = true
break
}
}
@ -377,7 +404,7 @@ func (r *registryRouter) Endpoint(req *http.Request) (*router.Route, error) {
// TODO: Percentage traffic
// we got here, so its a match
return e, nil
return endpoint, nil
}
// no match
@ -400,13 +427,13 @@ func (r *registryRouter) Route(req *http.Request) (*router.Route, error) {
// TODO: don't ignore that shit
// get the service name
rp, err := r.opts.Resolver.Resolve(req)
rsp, err := r.opts.Resolver.Resolve(req)
if err != nil {
return nil, err
}
// service name
name := rp.Name
name := rsp.Name
// get service
services, err := r.rc.GetService(name)
@ -429,7 +456,7 @@ func (r *registryRouter) Route(req *http.Request) (*router.Route, error) {
return &router.Route{
Service: name,
Endpoint: &router.Endpoint{
Name: rp.Method,
Name: rsp.Method,
Handler: handler,
},
Versions: services,
@ -462,8 +489,10 @@ func newRouter(opts ...router.Option) *registryRouter {
eps: make(map[string]*router.Route),
ceps: make(map[string]*endpoint),
}
go r.watch()
go r.refresh()
return r
}

View File

@ -23,15 +23,15 @@ type endpoint struct {
pcreregs []*regexp.Regexp
}
// router is the default router.
type staticRouter struct {
// Router is the default router.
type Router struct {
exit chan bool
opts router.Options
sync.RWMutex
eps map[string]*endpoint
}
func (r *staticRouter) isStopd() bool {
func (r *Router) isStopd() bool {
select {
case <-r.exit:
return true
@ -87,18 +87,20 @@ func (r *staticRouter) watch() {
}
*/
func (r *staticRouter) Register(route *router.Route) error {
ep := route.Endpoint
func (r *Router) Register(route *router.Route) error {
myEndpoint := route.Endpoint
if err := router.Validate(ep); err != nil {
if err := router.Validate(myEndpoint); err != nil {
return err
}
var pathregs []util.Pattern
var hostregs []*regexp.Regexp
var pcreregs []*regexp.Regexp
var (
pathregs []util.Pattern
hostregs []*regexp.Regexp
pcreregs []*regexp.Regexp
)
for _, h := range ep.Host {
for _, h := range myEndpoint.Host {
if h == "" || h == "*" {
continue
}
@ -106,10 +108,11 @@ func (r *staticRouter) Register(route *router.Route) error {
if err != nil {
return err
}
hostregs = append(hostregs, hostreg)
}
for _, p := range ep.Path {
for _, p := range myEndpoint.Path {
var pcreok bool
// pcre only when we have start and end markers
@ -129,63 +132,70 @@ func (r *staticRouter) Register(route *router.Route) error {
}
tpl := rule.Compile()
pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "", util.PatternLogger(r.Options().Logger))
if err != nil {
return err
}
pathregs = append(pathregs, pathreg)
}
r.Lock()
r.eps[ep.Name] = &endpoint{
apiep: ep,
r.eps[myEndpoint.Name] = &endpoint{
apiep: myEndpoint,
pcreregs: pcreregs,
pathregs: pathregs,
hostregs: hostregs,
}
r.Unlock()
return nil
}
func (r *staticRouter) Deregister(route *router.Route) error {
func (r *Router) Deregister(route *router.Route) error {
ep := route.Endpoint
if err := router.Validate(ep); err != nil {
return err
}
r.Lock()
delete(r.eps, ep.Name)
r.Unlock()
return nil
}
func (r *staticRouter) Options() router.Options {
func (r *Router) Options() router.Options {
return r.opts
}
func (r *staticRouter) Stop() error {
func (r *Router) Stop() error {
select {
case <-r.exit:
return nil
default:
close(r.exit)
}
return nil
}
func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
ep, err := r.endpoint(req)
func (r *Router) Endpoint(req *http.Request) (*router.Route, error) {
myEndpoint, err := r.endpoint(req)
if err != nil {
return nil, err
}
epf := strings.Split(ep.apiep.Name, ".")
epf := strings.Split(myEndpoint.apiep.Name, ".")
services, err := r.opts.Registry.GetService(epf[0])
if err != nil {
return nil, err
}
// hack for stream endpoint
if ep.apiep.Stream {
if myEndpoint.apiep.Stream {
svcs := rutil.Copy(services)
for _, svc := range svcs {
if len(svc.Endpoints) == 0 {
@ -195,6 +205,7 @@ func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
e.Metadata["stream"] = "true"
svc.Endpoints = append(svc.Endpoints, e)
}
for _, e := range svc.Endpoints {
e.Name = strings.Join(epf[1:], ".")
e.Metadata = make(map[string]string)
@ -210,10 +221,10 @@ func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
Endpoint: &router.Endpoint{
Name: strings.Join(epf[1:], "."),
Handler: "rpc",
Host: ep.apiep.Host,
Method: ep.apiep.Method,
Path: ep.apiep.Path,
Stream: ep.apiep.Stream,
Host: myEndpoint.apiep.Host,
Method: myEndpoint.apiep.Method,
Path: myEndpoint.apiep.Path,
Stream: myEndpoint.apiep.Stream,
},
Versions: services,
}
@ -221,8 +232,9 @@ func (r *staticRouter) Endpoint(req *http.Request) (*router.Route, error) {
return svc, nil
}
func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) {
func (r *Router) endpoint(req *http.Request) (*endpoint, error) {
logger := r.Options().Logger
if r.isStopd() {
return nil, errors.New("router closed")
}
@ -234,75 +246,85 @@ func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) {
if len(req.URL.Path) > 0 && req.URL.Path != "/" {
idx = 1
}
path := strings.Split(req.URL.Path[idx:], "/")
// use the first match
// TODO: weighted matching
for _, ep := range r.eps {
for _, myEndpoint := range r.eps {
var mMatch, hMatch, pMatch bool
// 1. try method
for _, m := range ep.apiep.Method {
for _, m := range myEndpoint.apiep.Method {
if m == req.Method {
mMatch = true
break
}
}
if !mMatch {
continue
}
logger.Logf(log.DebugLevel, "api method match %s", req.Method)
// 2. try host
if len(ep.apiep.Host) == 0 {
if len(myEndpoint.apiep.Host) == 0 {
hMatch = true
} else {
for idx, h := range ep.apiep.Host {
for idx, h := range myEndpoint.apiep.Host {
if h == "" || h == "*" {
hMatch = true
break
} else {
if ep.hostregs[idx].MatchString(req.URL.Host) {
hMatch = true
break
}
} else if myEndpoint.hostregs[idx].MatchString(req.URL.Host) {
hMatch = true
break
}
}
}
if !hMatch {
continue
}
logger.Logf(log.DebugLevel, "api host match %s", req.URL.Host)
// 3. try google.api path
for _, pathreg := range ep.pathregs {
for _, pathreg := range myEndpoint.pathregs {
matches, err := pathreg.Match(path, "")
if err != nil {
logger.Logf(log.DebugLevel, "api gpath not match %s != %v", path, pathreg)
continue
}
logger.Logf(log.DebugLevel, "api gpath match %s = %v", path, pathreg)
pMatch = true
ctx := req.Context()
md, ok := metadata.FromContext(ctx)
if !ok {
md = make(metadata.Metadata)
}
for k, v := range matches {
md[fmt.Sprintf("x-api-field-%s", k)] = v
}
*req = *req.Clone(metadata.NewContext(ctx, md))
break
}
if !pMatch {
// 4. try path via pcre path matching
for _, pathreg := range ep.pcreregs {
for _, pathreg := range myEndpoint.pcreregs {
if !pathreg.MatchString(req.URL.Path) {
logger.Logf(log.DebugLevel, "api pcre path not match %s != %v", req.URL.Path, pathreg)
continue
}
pMatch = true
break
}
}
@ -313,14 +335,14 @@ func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) {
// TODO: Percentage traffic
// we got here, so its a match
return ep, nil
return myEndpoint, nil
}
// no match
return nil, fmt.Errorf("endpoint not found for %v", req.URL)
}
func (r *staticRouter) Route(req *http.Request) (*router.Route, error) {
func (r *Router) Route(req *http.Request) (*router.Route, error) {
if r.isStopd() {
return nil, errors.New("router closed")
}
@ -334,9 +356,10 @@ func (r *staticRouter) Route(req *http.Request) (*router.Route, error) {
return ep, nil
}
func NewRouter(opts ...router.Option) *staticRouter {
// NewRouter returns a new static router.
func NewRouter(opts ...router.Option) *Router {
options := router.NewOptions(opts...)
r := &staticRouter{
r := &Router{
exit: make(chan bool),
opts: options,
eps: make(map[string]*endpoint),

View File

@ -1,6 +1,7 @@
package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/compile.go
// download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/compile.go
const (
opcodeVersion = 1
@ -66,6 +67,7 @@ func (v variable) compile() []op {
for _, s := range v.segments {
ops = append(ops, s.compile()...)
}
ops = append(ops, op{
code: OpConcatN,
operand: len(v.segments),
@ -88,7 +90,9 @@ func (t template) Compile() Template {
pool []string
fields []string
)
consts := make(map[string]int)
for _, op := range rawOps {
ops = append(ops, int(op.code))
if op.str == "" {
@ -100,10 +104,12 @@ func (t template) Compile() Template {
}
ops = append(ops, consts[op.str])
}
if op.code == OpCapture {
fields = append(fields, op.str)
}
}
return Template{
Version: opcodeVersion,
OpCodes: ops,

View File

@ -1,6 +1,7 @@
package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/parse.go
// download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/parse.go
import (
"fmt"
@ -24,9 +25,10 @@ func Parse(tmpl string) (Compiler, error) {
if !strings.HasPrefix(tmpl, "/") {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: "no leading /"}
}
tokens, verb := tokenize(tmpl[1:])
tokens, verb := tokenize(tmpl[1:])
p := parser{tokens: tokens}
segs, err := p.topLevelSegments()
if err != nil {
return template{}, InvalidTemplateError{tmpl: tmpl, msg: err.Error()}
@ -52,8 +54,10 @@ func tokenize(path string) (tokens []string, verb string) {
var (
st = init
)
for path != "" {
var idx int
switch st {
case init:
idx = strings.IndexAny(path, "/{")
@ -62,10 +66,12 @@ func tokenize(path string) (tokens []string, verb string) {
case nested:
idx = strings.IndexAny(path, "/}")
}
if idx < 0 {
tokens = append(tokens, path)
break
}
switch r := path[idx]; r {
case '/', '.':
case '{':
@ -75,22 +81,27 @@ func tokenize(path string) (tokens []string, verb string) {
case '}':
st = init
}
if idx == 0 {
tokens = append(tokens, path[idx:idx+1])
} else {
tokens = append(tokens, path[:idx], path[idx:idx+1])
}
path = path[idx+1:]
}
l := len(tokens)
t := tokens[l-1]
if idx := strings.LastIndex(t, ":"); idx == 0 {
tokens, verb = tokens[:l-1], t[1:]
} else if idx > 0 {
tokens[l-1], verb = t[:idx], t[idx+1:]
}
tokens = append(tokens, eof)
return tokens, verb
}
@ -111,17 +122,22 @@ func (p *parser) topLevelSegments() ([]segment, error) {
if err != nil {
return nil, err
}
logger.Logf(log.DebugLevel, "accept segments: %q; %q", p.accepted, p.tokens)
if _, err := p.accept(typeEOF); err != nil {
return nil, fmt.Errorf("unexpected token %q after segments %q", p.tokens[0], strings.Join(p.accepted, ""))
}
logger.Logf(log.DebugLevel, "accept eof: %q; %q", p.accepted, p.tokens)
return segs, nil
}
func (p *parser) segments() ([]segment, error) {
logger := log.LoggerOrDefault(p.logger)
s, err := p.segment()
if err != nil {
return nil, err
}
@ -133,11 +149,14 @@ func (p *parser) segments() ([]segment, error) {
if _, err := p.accept("/"); err != nil {
return segs, nil
}
s, err := p.segment()
if err != nil {
return segs, err
}
segs = append(segs, s)
logger.Logf(log.DebugLevel, "accept segment: %q; %q", p.accepted, p.tokens)
}
}
@ -146,17 +165,20 @@ func (p *parser) segment() (segment, error) {
if _, err := p.accept("*"); err == nil {
return wildcard{}, nil
}
if _, err := p.accept("**"); err == nil {
return deepWildcard{}, nil
}
if l, err := p.literal(); err == nil {
return l, nil
}
v, err := p.variable()
if err != nil {
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %v", err)
return nil, fmt.Errorf("segment neither wildcards, literal or variable: %w", err)
}
return v, err
}
@ -165,6 +187,7 @@ func (p *parser) literal() (segment, error) {
if err != nil {
return nil, err
}
return literal(lit), nil
}
@ -182,7 +205,7 @@ func (p *parser) variable() (segment, error) {
if _, err := p.accept("="); err == nil {
segs, err = p.segments()
if err != nil {
return nil, fmt.Errorf("invalid segment in variable %q: %v", path, err)
return nil, fmt.Errorf("invalid segment in variable %q: %w", path, err)
}
} else {
segs = []segment{wildcard{}}
@ -191,6 +214,7 @@ func (p *parser) variable() (segment, error) {
if _, err := p.accept("}"); err != nil {
return nil, fmt.Errorf("unterminated variable segment: %s", path)
}
return variable{
path: path,
segments: segs,
@ -202,7 +226,9 @@ func (p *parser) fieldPath() (string, error) {
if err != nil {
return "", err
}
components := []string{c}
for {
if _, err = p.accept("."); err != nil {
return strings.Join(components, "."), nil
@ -238,6 +264,7 @@ const (
// If it doesn't match, the function does not consume any tokens and return an error.
func (p *parser) accept(term termType) (string, error) {
t := p.tokens[0]
switch term {
case "/", "*", "**", ".", "=", "{", "}":
if t != string(term) && t != "/" {
@ -258,8 +285,10 @@ func (p *parser) accept(term termType) (string, error) {
default:
return "", fmt.Errorf("unknown termType %q", term)
}
p.tokens = p.tokens[1:]
p.accepted = append(p.accepted, t)
return t, nil
}
@ -278,6 +307,7 @@ func expectPChars(t string) error {
pct1
pct2
)
st := init
for _, r := range t {
if st != init {
@ -316,9 +346,11 @@ func expectPChars(t string) error {
return fmt.Errorf("invalid character in path segment: %q(%U)", r, r)
}
}
if st != init {
return fmt.Errorf("invalid percent-encoding in %q", t)
}
return nil
}
@ -327,12 +359,14 @@ func expectIdent(ident string) error {
if ident == "" {
return fmt.Errorf("empty identifier")
}
for pos, r := range ident {
switch {
case '0' <= r && r <= '9':
if pos == 0 {
return fmt.Errorf("identifier starting with digit: %s", ident)
}
continue
case 'A' <= r && r <= 'Z':
continue
@ -344,6 +378,7 @@ func expectIdent(ident string) error {
return fmt.Errorf("invalid character %q(%U) in identifier: %s", r, r, ident)
}
}
return nil
}
@ -356,5 +391,6 @@ func isHexDigit(r rune) bool {
case 'a' <= r && r <= 'f':
return true
}
return false
}

View File

@ -49,7 +49,7 @@ type patternOptions struct {
// PatternOpt is an option for creating Patterns.
type PatternOpt func(*patternOptions)
// Logger sets the logger.
// PatternLogger sets the logger.
func PatternLogger(l log.Logger) PatternOpt {
return func(po *patternOptions) {
po.logger = l
@ -89,8 +89,10 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
pushMSeen bool
vars []string
)
for i := 0; i < l; i += 2 {
op := rop{code: OpCode(ops[i]), operand: ops[i+1]}
switch op.code {
case OpNop:
continue
@ -104,6 +106,7 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "pushM appears twice")
return Pattern{}, ErrInvalidPattern
}
pushMSeen = true
stack++
case OpLitPush:
@ -111,6 +114,7 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "negative literal index: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
if pushMSeen {
tailLen++
}
@ -120,6 +124,7 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "negative concat size: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
stack -= op.operand
if stack < 0 {
logger.Logf(log.DebugLevel, "stack underflow")
@ -131,10 +136,12 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
logger.Logf(log.DebugLevel, "variable name index out of bound: %d", op.operand)
return Pattern{}, ErrInvalidPattern
}
v := pool[op.operand]
op.operand = len(vars)
vars = append(vars, v)
stack--
if stack < 0 {
logger.Logf(log.DebugLevel, "stack underflow")
return Pattern{}, ErrInvalidPattern
@ -147,8 +154,10 @@ func NewPattern(version int, ops []int, pool []string, verb string, opts ...Patt
if maxstack < stack {
maxstack = stack
}
typedOps = append(typedOps, op)
}
return Pattern{
ops: typedOps,
pool: pool,

View File

@ -1,6 +1,7 @@
package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types.go
// download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types.go
import (
"fmt"
@ -46,6 +47,7 @@ func (v variable) String() string {
for _, s := range v.segments {
segs = append(segs, s.String())
}
return fmt.Sprintf("{%s=%s}", v.path, strings.Join(segs, "/"))
}
@ -54,9 +56,11 @@ func (t template) String() string {
for _, s := range t.segments {
segs = append(segs, s.String())
}
str := strings.Join(segs, "/")
if t.verb != "" {
str = fmt.Sprintf("%s:%s", str, t.verb)
}
return "/" + str
}

View File

@ -1,6 +1,7 @@
package util
// download from https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types_test.go
// download from
// https://raw.githubusercontent.com/grpc-ecosystem/grpc-gateway/master/protoc-gen-grpc-gateway/httprule/types_test.go
import (
"fmt"