1
0
mirror of https://github.com/go-micro/go-micro.git synced 2024-12-24 10:07:04 +02:00
go-micro/api/handler/rpc/stream.go

151 lines
2.9 KiB
Go

package rpc
import (
"context"
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/websocket"
"github.com/micro/go-micro/v2/api"
"github.com/micro/go-micro/v2/client"
"github.com/micro/go-micro/v2/client/selector"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
// serveWebsocket will stream rpc back over websockets assuming json
func serveWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, service *api.Service, c client.Client) {
// upgrade the connection
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
// close on exit
defer conn.Close()
// wait for the first request so we know
_, p, err := conn.ReadMessage()
if err != nil {
return
}
// send to backend
// default to trying json
var request json.RawMessage
// if the extracted payload isn't empty lets use it
if len(p) > 0 {
request = json.RawMessage(p)
}
// create a request to the backend
req := c.NewRequest(
service.Name,
service.Endpoint.Name,
&request,
client.WithContentType("application/json"),
)
so := selector.WithStrategy(strategy(service.Services))
// create a new stream
stream, err := c.Stream(ctx, req, client.WithSelectOption(so))
if err != nil {
return
}
// send the first request for the client
// since
if err := stream.Send(request); err != nil {
return
}
go writeLoop(conn, stream)
resp := stream.Response()
// receive from stream and send to client
for {
// read backend response body
body, err := resp.Read()
if err != nil {
return
}
// write the response
if err := conn.WriteMessage(websocket.TextMessage, body); err != nil {
return
}
}
}
// writeLoop
func writeLoop(conn *websocket.Conn, stream client.Stream) {
// close stream when done
defer stream.Close()
for {
_, p, err := conn.ReadMessage()
if err != nil {
return
}
// send to backend
// default to trying json
var request json.RawMessage
// if the extracted payload isn't empty lets use it
if len(p) > 0 {
request = json.RawMessage(p)
}
if err := stream.Send(request); err != nil {
return
}
}
}
func isStream(r *http.Request, srv *api.Service) bool {
// check if it's a web socket
if !isWebSocket(r) {
return false
}
// check if the endpoint supports streaming
for _, service := range srv.Services {
for _, ep := range service.Endpoints {
// skip if it doesn't match the name
if ep.Name != srv.Endpoint.Name {
continue
}
// matched if the name
if v := ep.Metadata["stream"]; v == "true" {
return true
}
}
}
return false
}
func isWebSocket(r *http.Request) bool {
contains := func(key, val string) bool {
vv := strings.Split(r.Header.Get(key), ",")
for _, v := range vv {
if val == strings.ToLower(strings.TrimSpace(v)) {
return true
}
}
return false
}
if contains("Connection", "upgrade") && contains("Upgrade", "websocket") {
return true
}
return false
}