package jsonrpc2 import ( "encoding/json" "errors" "fmt" "io" "sync" "go-micro.dev/v4/codec" ) type serverCodec struct { encmutex sync.Mutex // protects enc dec *json.Decoder // for reading JSON values enc *json.Encoder // for writing JSON values c io.Closer // temporary work space req serverRequest // JSON-RPC clients can use arbitrary json values as request IDs. // Package rpc expects uint64 request IDs. // We assign uint64 sequence numbers to incoming requests // but save the original request ID in the pending map. // When rpc responds, we use the sequence number in // the response to find the original request ID. mutex sync.Mutex // protects seq, pending seq uint64 pending map[interface{}]*json.RawMessage } func newServerCodec(conn io.ReadWriteCloser) *serverCodec { return &serverCodec{ dec: json.NewDecoder(conn), enc: json.NewEncoder(conn), c: conn, pending: make(map[interface{}]*json.RawMessage), } } type serverRequest struct { Version string `json:"jsonrpc"` Method string `json:"method"` Params *json.RawMessage `json:"params"` ID *json.RawMessage `json:"id"` } func (r *serverRequest) reset() { r.Version = "" r.Method = "" r.Params = nil r.ID = nil } func (r *serverRequest) UnmarshalJSON(raw []byte) error { r.reset() type req *serverRequest if err := json.Unmarshal(raw, req(r)); err != nil { return errors.New("bad request") } var o = make(map[string]*json.RawMessage) if err := json.Unmarshal(raw, &o); err != nil { return errors.New("bad request") } if o["jsonrpc"] == nil || o["method"] == nil { return errors.New("bad request") } _, okID := o["id"] _, okParams := o["params"] if len(o) == 3 && !(okID || okParams) || len(o) == 4 && !(okID && okParams) || len(o) > 4 { return errors.New("bad request") } if r.Version != "2.0" { return errors.New("bad request") } if okParams { if r.Params == nil || len(*r.Params) == 0 { return errors.New("bad request") } switch []byte(*r.Params)[0] { case '[', '{': default: return errors.New("bad request") } } if okID && r.ID == nil { r.ID = &null } if okID { if len(*r.ID) == 0 { return errors.New("bad request") } switch []byte(*r.ID)[0] { case 't', 'f', '{', '[': return errors.New("bad request") } } return nil } type serverResponse struct { Version string `json:"jsonrpc"` ID *json.RawMessage `json:"id"` Result interface{} `json:"result,omitempty"` Error interface{} `json:"error,omitempty"` } func (c *serverCodec) ReadHeader(m *codec.Message) (err error) { // If return error: // - codec will be closed // So, try to send error reply to client before returning error. c.req.reset() var raw json.RawMessage if err := c.dec.Decode(&raw); err != nil { c.encmutex.Lock() c.enc.Encode(serverResponse{Version: "2.0", ID: &null, Error: errParse}) c.encmutex.Unlock() return err } if err := json.Unmarshal(raw, &c.req); err != nil { if err.Error() == "bad request" { c.encmutex.Lock() c.enc.Encode(serverResponse{Version: "2.0", ID: &null, Error: errRequest}) c.encmutex.Unlock() } return err } m.Endpoint = c.req.Method // JSON request id can be any JSON value; // RPC package expects uint64. Translate to // internal uint64 and save JSON on the side. c.mutex.Lock() c.seq++ c.pending[c.seq] = c.req.ID c.req.ID = nil m.Id = fmt.Sprintf("%d", c.seq) c.mutex.Unlock() return nil } func (c *serverCodec) ReadBody(x interface{}) error { // If x!=nil and return error e: // - WriteResponse() will be called with e.Error() in r.Error if x == nil { return nil } if c.req.Params == nil { return nil } if err := json.Unmarshal(*c.req.Params, x); err != nil { return NewError(errParams.Code, err.Error()) } return nil } var null = json.RawMessage([]byte("null")) func (c *serverCodec) Write(m *codec.Message, x interface{}) error { // If return error: nothing happens. // In r.Error will be "" or .Error() of error returned by: // - ReadRequestBody() // - called RPC method c.mutex.Lock() b, ok := c.pending[m.Id] if !ok { c.mutex.Unlock() fmt.Println("invalid sequence number in response", m.Id) return errors.New("invalid sequence number in response") } c.mutex.Unlock() if replies, ok := x.(*[]*json.RawMessage); m.Endpoint == "JSONRPC2.Batch" && ok { if len(*replies) == 0 { return nil } c.encmutex.Lock() defer c.encmutex.Unlock() return c.enc.Encode(replies) } if b == nil { // Notification. Do not respond. return nil } resp := serverResponse{Version: "2.0", ID: b} if m.Error == "" { if x == nil { resp.Result = &null } else { resp.Result = x } } else if m.Error[0] == '{' && m.Error[len(m.Error)-1] == '}' { // Well… this check for '{'…'}' isn't too strict, but I // suppose we're trusting our own RPC methods (this way they // can force sending wrong reply or many replies instead // of one) and normal errors won't be formatted this way. raw := json.RawMessage(m.Error) resp.Error = &raw } else { raw := json.RawMessage(newError(m.Error).Error()) resp.Error = &raw } c.encmutex.Lock() defer c.encmutex.Unlock() return c.enc.Encode(resp) } func (c *serverCodec) Close() error { return c.c.Close() }