1
0
mirror of https://github.com/open-telemetry/opentelemetry-go.git synced 2026-06-03 18:35:08 +02:00

upgrade thrift to v0.14.1 in jaeger exporter (#1712)

* upgrade thrift to v0.14.1 in jaeger exporter

* remove jaeger exporter vendor folder

* update PR number in changlog

Co-authored-by: Tyler Yahn <MrAlias@users.noreply.github.com>
This commit is contained in:
Kevin Schneider
2021-03-22 14:12:56 -05:00
committed by GitHub
parent 5a6a854d50
commit a9b2f85134
55 changed files with 7170 additions and 1892 deletions
+1 -2
View File
@@ -10,6 +10,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
### Changed
- Jaeger exporter was updated to use thrift v0.14.1. (#1712)
- Migrate from using internally built and maintained version of the OTLP to the one hosted at `go.opentelemetry.io/proto/otlp`. (#1713)
- Migrate from using `github.com/gogo/protobuf` to `google.golang.org/protobuf` to match `go.opentelemetry.io/proto/otlp`. (#1713)
@@ -38,8 +39,6 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
- Renamed the `LabelSet` method of `"go.opentelemetry.io/otel/sdk/resource".Resource` to `Set`. (#1692)
- Changed `WithSDK` to `WithSDKOptions` to accept variadic arguments of `TracerProviderOption` type in `go.opentelemetry.io/otel/exporters/trace/jaeger` package. (#1693)
- Changed `WithSDK` to `WithSDKOptions` to accept variadic arguments of `TracerProviderOption` type in `go.opentelemetry.io/otel/exporters/trace/zipkin` package. (#1693)
- `"go.opentelemetry.io/otel/sdk/resource".NewWithAttributes` will now drop any invalid attributes passed. (#1703)
- `"go.opentelemetry.io/otel/sdk/resource".StringDetector` will now error if the produced attribute is invalid. (#1703)
### Removed
+1 -1
View File
@@ -9,4 +9,4 @@ go get -u go.opentelemetry.io/otel/exporters/trace/jaeger
## Maintenance
This exporter uses a vendored copy of the Apache Thrift library (v0.13.0) at a custom import path. When re-generating Thrift code in future, please adapt import paths as necessary.
This exporter uses a vendored copy of the Apache Thrift library (v0.14.1) at a custom import path. When re-generating Thrift code in future, please adapt import paths as necessary.
+7 -6
View File
@@ -15,6 +15,7 @@
package jaeger // import "go.opentelemetry.io/otel/exporters/trace/jaeger"
import (
"context"
"fmt"
"io"
"log"
@@ -23,6 +24,7 @@ import (
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
genAgent "go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/agent"
gen "go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/jaeger"
)
@@ -31,11 +33,11 @@ const udpPacketMaxLength = 65000
// agentClientUDP is a UDP client to Jaeger agent that implements gen.Agent interface.
type agentClientUDP struct {
gen.Agent
genAgent.Agent
io.Closer
connUDP udpConn
client *gen.AgentClient
client *genAgent.AgentClient
maxPacketSize int // max size of datagram in bytes
thriftBuffer *thrift.TMemoryBuffer // buffer used to calculate byte size of a span
}
@@ -70,8 +72,8 @@ func newAgentClientUDP(params agentClientUDPParams) (*agentClientUDP, error) {
}
thriftBuffer := thrift.NewTMemoryBufferLen(params.MaxPacketSize)
protocolFactory := thrift.NewTCompactProtocolFactory()
client := gen.NewAgentClientFactory(thriftBuffer, protocolFactory)
protocolFactory := thrift.NewTCompactProtocolFactoryConf(&thrift.TConfiguration{})
client := genAgent.NewAgentClientFactory(thriftBuffer, protocolFactory)
var connUDP udpConn
var err error
@@ -109,8 +111,7 @@ func newAgentClientUDP(params agentClientUDPParams) (*agentClientUDP, error) {
// EmitBatch implements EmitBatch() of Agent interface
func (a *agentClientUDP) EmitBatch(batch *gen.Batch) error {
a.thriftBuffer.Reset()
a.client.SeqId = 0 // we have no need for distinct SeqIds for our one-way UDP messages
if err := a.client.EmitBatch(batch); err != nil {
if err := a.client.EmitBatch(context.Background(), batch); err != nil {
return err
}
if a.thriftBuffer.Len() > a.maxPacketSize {
@@ -0,0 +1,6 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package agent
var GoUnusedProtection__ int;
@@ -0,0 +1,26 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package agent
import (
"bytes"
"context"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/jaeger"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/zipkincore"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"time"
)
// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = context.Background
var _ = time.Now
var _ = bytes.Equal
var _ = jaeger.GoUnusedProtection__
var _ = zipkincore.GoUnusedProtection__
func init() {
}
@@ -0,0 +1,209 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package main
import (
"context"
"flag"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/agent"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/jaeger"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/zipkincore"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
)
var _ = jaeger.GoUnusedProtection__
var _ = zipkincore.GoUnusedProtection__
var _ = agent.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
flag.PrintDefaults()
fmt.Fprintln(os.Stderr, "\nFunctions:")
fmt.Fprintln(os.Stderr, " void emitZipkinBatch( spans)")
fmt.Fprintln(os.Stderr, " void emitBatch(Batch batch)")
fmt.Fprintln(os.Stderr)
os.Exit(0)
}
type httpHeaders map[string]string
func (h httpHeaders) String() string {
var m map[string]string = h
return fmt.Sprintf("%s", m)
}
func (h httpHeaders) Set(value string) error {
parts := strings.Split(value, ": ")
if len(parts) != 2 {
return fmt.Errorf("header should be of format 'Key: Value'")
}
h[parts[0]] = parts[1]
return nil
}
func main() {
flag.Usage = Usage
var host string
var port int
var protocol string
var urlString string
var framed bool
var useHttp bool
headers := make(httpHeaders)
var parsedUrl *url.URL
var trans thrift.TTransport
_ = strconv.Atoi
_ = math.Abs
flag.Usage = Usage
flag.StringVar(&host, "h", "localhost", "Specify host and port")
flag.IntVar(&port, "p", 9090, "Specify port")
flag.StringVar(&protocol, "P", "binary", "Specify the protocol (binary, compact, simplejson, json)")
flag.StringVar(&urlString, "u", "", "Specify the url")
flag.BoolVar(&framed, "framed", false, "Use framed transport")
flag.BoolVar(&useHttp, "http", false, "Use http")
flag.Var(headers, "H", "Headers to set on the http(s) request (e.g. -H \"Key: Value\")")
flag.Parse()
if len(urlString) > 0 {
var err error
parsedUrl, err = url.Parse(urlString)
if err != nil {
fmt.Fprintln(os.Stderr, "Error parsing URL: ", err)
flag.Usage()
}
host = parsedUrl.Host
useHttp = len(parsedUrl.Scheme) <= 0 || parsedUrl.Scheme == "http" || parsedUrl.Scheme == "https"
} else if useHttp {
_, err := url.Parse(fmt.Sprint("http://", host, ":", port))
if err != nil {
fmt.Fprintln(os.Stderr, "Error parsing URL: ", err)
flag.Usage()
}
}
cmd := flag.Arg(0)
var err error
if useHttp {
trans, err = thrift.NewTHttpClient(parsedUrl.String())
if len(headers) > 0 {
httptrans := trans.(*thrift.THttpClient)
for key, value := range headers {
httptrans.SetHeader(key, value)
}
}
} else {
portStr := fmt.Sprint(port)
if strings.Contains(host, ":") {
host, portStr, err = net.SplitHostPort(host)
if err != nil {
fmt.Fprintln(os.Stderr, "error with host:", err)
os.Exit(1)
}
}
trans, err = thrift.NewTSocket(net.JoinHostPort(host, portStr))
if err != nil {
fmt.Fprintln(os.Stderr, "error resolving address:", err)
os.Exit(1)
}
if framed {
trans = thrift.NewTFramedTransport(trans)
}
}
if err != nil {
fmt.Fprintln(os.Stderr, "Error creating transport", err)
os.Exit(1)
}
defer trans.Close()
var protocolFactory thrift.TProtocolFactory
switch protocol {
case "compact":
protocolFactory = thrift.NewTCompactProtocolFactory()
break
case "simplejson":
protocolFactory = thrift.NewTSimpleJSONProtocolFactory()
break
case "json":
protocolFactory = thrift.NewTJSONProtocolFactory()
break
case "binary", "":
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
break
default:
fmt.Fprintln(os.Stderr, "Invalid protocol specified: ", protocol)
Usage()
os.Exit(1)
}
iprot := protocolFactory.GetProtocol(trans)
oprot := protocolFactory.GetProtocol(trans)
client := agent.NewAgentClient(thrift.NewTStandardClient(iprot, oprot))
if err := trans.Open(); err != nil {
fmt.Fprintln(os.Stderr, "Error opening socket to ", host, ":", port, " ", err)
os.Exit(1)
}
switch cmd {
case "emitZipkinBatch":
if flag.NArg()-1 != 1 {
fmt.Fprintln(os.Stderr, "EmitZipkinBatch requires 1 args")
flag.Usage()
}
arg5 := flag.Arg(1)
mbTrans6 := thrift.NewTMemoryBufferLen(len(arg5))
defer mbTrans6.Close()
_, err7 := mbTrans6.WriteString(arg5)
if err7 != nil {
Usage()
return
}
factory8 := thrift.NewTJSONProtocolFactory()
jsProt9 := factory8.GetProtocol(mbTrans6)
containerStruct0 := agent.NewAgentEmitZipkinBatchArgs()
err10 := containerStruct0.ReadField1(context.Background(), jsProt9)
if err10 != nil {
Usage()
return
}
argvalue0 := containerStruct0.Spans
value0 := argvalue0
fmt.Print(client.EmitZipkinBatch(context.Background(), value0))
fmt.Print("\n")
break
case "emitBatch":
if flag.NArg()-1 != 1 {
fmt.Fprintln(os.Stderr, "EmitBatch requires 1 args")
flag.Usage()
}
arg11 := flag.Arg(1)
mbTrans12 := thrift.NewTMemoryBufferLen(len(arg11))
defer mbTrans12.Close()
_, err13 := mbTrans12.WriteString(arg11)
if err13 != nil {
Usage()
return
}
factory14 := thrift.NewTJSONProtocolFactory()
jsProt15 := factory14.GetProtocol(mbTrans12)
argvalue0 := jaeger.NewBatch()
err16 := argvalue0.Read(context.Background(), jsProt15)
if err16 != nil {
Usage()
return
}
value0 := argvalue0
fmt.Print(client.EmitBatch(context.Background(), value0))
fmt.Print("\n")
break
case "":
Usage()
break
default:
fmt.Fprintln(os.Stderr, "Invalid function ", cmd)
}
}
@@ -0,0 +1,411 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package agent
import (
"bytes"
"context"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/jaeger"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/zipkincore"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"time"
)
// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = context.Background
var _ = time.Now
var _ = bytes.Equal
var _ = jaeger.GoUnusedProtection__
var _ = zipkincore.GoUnusedProtection__
type Agent interface {
// Parameters:
// - Spans
EmitZipkinBatch(ctx context.Context, spans []*zipkincore.Span) (_err error)
// Parameters:
// - Batch
EmitBatch(ctx context.Context, batch *jaeger.Batch) (_err error)
}
type AgentClient struct {
c thrift.TClient
meta thrift.ResponseMeta
}
func NewAgentClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *AgentClient {
return &AgentClient{
c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t)),
}
}
func NewAgentClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *AgentClient {
return &AgentClient{
c: thrift.NewTStandardClient(iprot, oprot),
}
}
func NewAgentClient(c thrift.TClient) *AgentClient {
return &AgentClient{
c: c,
}
}
func (p *AgentClient) Client_() thrift.TClient {
return p.c
}
func (p *AgentClient) LastResponseMeta_() thrift.ResponseMeta {
return p.meta
}
func (p *AgentClient) SetLastResponseMeta_(meta thrift.ResponseMeta) {
p.meta = meta
}
// Parameters:
// - Spans
func (p *AgentClient) EmitZipkinBatch(ctx context.Context, spans []*zipkincore.Span) (_err error) {
var _args0 AgentEmitZipkinBatchArgs
_args0.Spans = spans
p.SetLastResponseMeta_(thrift.ResponseMeta{})
if _, err := p.Client_().Call(ctx, "emitZipkinBatch", &_args0, nil); err != nil {
return err
}
return nil
}
// Parameters:
// - Batch
func (p *AgentClient) EmitBatch(ctx context.Context, batch *jaeger.Batch) (_err error) {
var _args1 AgentEmitBatchArgs
_args1.Batch = batch
p.SetLastResponseMeta_(thrift.ResponseMeta{})
if _, err := p.Client_().Call(ctx, "emitBatch", &_args1, nil); err != nil {
return err
}
return nil
}
type AgentProcessor struct {
processorMap map[string]thrift.TProcessorFunction
handler Agent
}
func (p *AgentProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
p.processorMap[key] = processor
}
func (p *AgentProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
processor, ok = p.processorMap[key]
return processor, ok
}
func (p *AgentProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
return p.processorMap
}
func NewAgentProcessor(handler Agent) *AgentProcessor {
self2 := &AgentProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
self2.processorMap["emitZipkinBatch"] = &agentProcessorEmitZipkinBatch{handler: handler}
self2.processorMap["emitBatch"] = &agentProcessorEmitBatch{handler: handler}
return self2
}
func (p *AgentProcessor) Process(ctx context.Context, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
name, _, seqId, err2 := iprot.ReadMessageBegin(ctx)
if err2 != nil {
return false, thrift.WrapTException(err2)
}
if processor, ok := p.GetProcessorFunction(name); ok {
return processor.Process(ctx, seqId, iprot, oprot)
}
iprot.Skip(ctx, thrift.STRUCT)
iprot.ReadMessageEnd(ctx)
x3 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
oprot.WriteMessageBegin(ctx, name, thrift.EXCEPTION, seqId)
x3.Write(ctx, oprot)
oprot.WriteMessageEnd(ctx)
oprot.Flush(ctx)
return false, x3
}
type agentProcessorEmitZipkinBatch struct {
handler Agent
}
func (p *agentProcessorEmitZipkinBatch) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
args := AgentEmitZipkinBatchArgs{}
var err2 error
if err2 = args.Read(ctx, iprot); err2 != nil {
iprot.ReadMessageEnd(ctx)
return false, thrift.WrapTException(err2)
}
iprot.ReadMessageEnd(ctx)
tickerCancel := func() {}
_ = tickerCancel
if err2 = p.handler.EmitZipkinBatch(ctx, args.Spans); err2 != nil {
tickerCancel()
return true, thrift.WrapTException(err2)
}
tickerCancel()
return true, nil
}
type agentProcessorEmitBatch struct {
handler Agent
}
func (p *agentProcessorEmitBatch) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
args := AgentEmitBatchArgs{}
var err2 error
if err2 = args.Read(ctx, iprot); err2 != nil {
iprot.ReadMessageEnd(ctx)
return false, thrift.WrapTException(err2)
}
iprot.ReadMessageEnd(ctx)
tickerCancel := func() {}
_ = tickerCancel
if err2 = p.handler.EmitBatch(ctx, args.Batch); err2 != nil {
tickerCancel()
return true, thrift.WrapTException(err2)
}
tickerCancel()
return true, nil
}
// HELPER FUNCTIONS AND STRUCTURES
// Attributes:
// - Spans
type AgentEmitZipkinBatchArgs struct {
Spans []*zipkincore.Span `thrift:"spans,1" db:"spans" json:"spans"`
}
func NewAgentEmitZipkinBatchArgs() *AgentEmitZipkinBatchArgs {
return &AgentEmitZipkinBatchArgs{}
}
func (p *AgentEmitZipkinBatchArgs) GetSpans() []*zipkincore.Span {
return p.Spans
}
func (p *AgentEmitZipkinBatchArgs) Read(ctx context.Context, iprot thrift.TProtocol) error {
if _, err := iprot.ReadStructBegin(ctx); err != nil {
return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
}
for {
_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)
if err != nil {
return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
}
if fieldTypeId == thrift.STOP {
break
}
switch fieldId {
case 1:
if fieldTypeId == thrift.LIST {
if err := p.ReadField1(ctx, iprot); err != nil {
return err
}
} else {
if err := iprot.Skip(ctx, fieldTypeId); err != nil {
return err
}
}
default:
if err := iprot.Skip(ctx, fieldTypeId); err != nil {
return err
}
}
if err := iprot.ReadFieldEnd(ctx); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(ctx); err != nil {
return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
}
return nil
}
func (p *AgentEmitZipkinBatchArgs) ReadField1(ctx context.Context, iprot thrift.TProtocol) error {
_, size, err := iprot.ReadListBegin(ctx)
if err != nil {
return thrift.PrependError("error reading list begin: ", err)
}
tSlice := make([]*zipkincore.Span, 0, size)
p.Spans = tSlice
for i := 0; i < size; i++ {
_elem4 := &zipkincore.Span{}
if err := _elem4.Read(ctx, iprot); err != nil {
return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", _elem4), err)
}
p.Spans = append(p.Spans, _elem4)
}
if err := iprot.ReadListEnd(ctx); err != nil {
return thrift.PrependError("error reading list end: ", err)
}
return nil
}
func (p *AgentEmitZipkinBatchArgs) Write(ctx context.Context, oprot thrift.TProtocol) error {
if err := oprot.WriteStructBegin(ctx, "emitZipkinBatch_args"); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
}
if p != nil {
if err := p.writeField1(ctx, oprot); err != nil {
return err
}
}
if err := oprot.WriteFieldStop(ctx); err != nil {
return thrift.PrependError("write field stop error: ", err)
}
if err := oprot.WriteStructEnd(ctx); err != nil {
return thrift.PrependError("write struct stop error: ", err)
}
return nil
}
func (p *AgentEmitZipkinBatchArgs) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) {
if err := oprot.WriteFieldBegin(ctx, "spans", thrift.LIST, 1); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:spans: ", p), err)
}
if err := oprot.WriteListBegin(ctx, thrift.STRUCT, len(p.Spans)); err != nil {
return thrift.PrependError("error writing list begin: ", err)
}
for _, v := range p.Spans {
if err := v.Write(ctx, oprot); err != nil {
return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", v), err)
}
}
if err := oprot.WriteListEnd(ctx); err != nil {
return thrift.PrependError("error writing list end: ", err)
}
if err := oprot.WriteFieldEnd(ctx); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write field end error 1:spans: ", p), err)
}
return err
}
func (p *AgentEmitZipkinBatchArgs) String() string {
if p == nil {
return "<nil>"
}
return fmt.Sprintf("AgentEmitZipkinBatchArgs(%+v)", *p)
}
// Attributes:
// - Batch
type AgentEmitBatchArgs struct {
Batch *jaeger.Batch `thrift:"batch,1" db:"batch" json:"batch"`
}
func NewAgentEmitBatchArgs() *AgentEmitBatchArgs {
return &AgentEmitBatchArgs{}
}
var AgentEmitBatchArgs_Batch_DEFAULT *jaeger.Batch
func (p *AgentEmitBatchArgs) GetBatch() *jaeger.Batch {
if !p.IsSetBatch() {
return AgentEmitBatchArgs_Batch_DEFAULT
}
return p.Batch
}
func (p *AgentEmitBatchArgs) IsSetBatch() bool {
return p.Batch != nil
}
func (p *AgentEmitBatchArgs) Read(ctx context.Context, iprot thrift.TProtocol) error {
if _, err := iprot.ReadStructBegin(ctx); err != nil {
return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
}
for {
_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin(ctx)
if err != nil {
return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
}
if fieldTypeId == thrift.STOP {
break
}
switch fieldId {
case 1:
if fieldTypeId == thrift.STRUCT {
if err := p.ReadField1(ctx, iprot); err != nil {
return err
}
} else {
if err := iprot.Skip(ctx, fieldTypeId); err != nil {
return err
}
}
default:
if err := iprot.Skip(ctx, fieldTypeId); err != nil {
return err
}
}
if err := iprot.ReadFieldEnd(ctx); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(ctx); err != nil {
return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
}
return nil
}
func (p *AgentEmitBatchArgs) ReadField1(ctx context.Context, iprot thrift.TProtocol) error {
p.Batch = &jaeger.Batch{}
if err := p.Batch.Read(ctx, iprot); err != nil {
return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Batch), err)
}
return nil
}
func (p *AgentEmitBatchArgs) Write(ctx context.Context, oprot thrift.TProtocol) error {
if err := oprot.WriteStructBegin(ctx, "emitBatch_args"); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
}
if p != nil {
if err := p.writeField1(ctx, oprot); err != nil {
return err
}
}
if err := oprot.WriteFieldStop(ctx); err != nil {
return thrift.PrependError("write field stop error: ", err)
}
if err := oprot.WriteStructEnd(ctx); err != nil {
return thrift.PrependError("write struct stop error: ", err)
}
return nil
}
func (p *AgentEmitBatchArgs) writeField1(ctx context.Context, oprot thrift.TProtocol) (err error) {
if err := oprot.WriteFieldBegin(ctx, "batch", thrift.STRUCT, 1); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:batch: ", p), err)
}
if err := p.Batch.Write(ctx, oprot); err != nil {
return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Batch), err)
}
if err := oprot.WriteFieldEnd(ctx); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write field end error 1:batch: ", p), err)
}
return err
}
func (p *AgentEmitBatchArgs) String() string {
if p == nil {
return "<nil>"
}
return fmt.Sprintf("AgentEmitBatchArgs(%+v)", *p)
}
@@ -1,6 +1,6 @@
// Autogenerated by Thrift Compiler (0.11.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package jaeger
var GoUnusedProtection__ int
var GoUnusedProtection__ int;
@@ -1,244 +0,0 @@
// Autogenerated by Thrift Compiler (0.9.3)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package jaeger
import (
"bytes"
"context"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
)
// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal
type Agent interface {
// Parameters:
// - Batch
EmitBatch(batch *Batch) (err error)
}
type AgentClient struct {
Transport thrift.TTransport
ProtocolFactory thrift.TProtocolFactory
InputProtocol thrift.TProtocol
OutputProtocol thrift.TProtocol
SeqId int32
}
func NewAgentClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *AgentClient {
return &AgentClient{Transport: t,
ProtocolFactory: f,
InputProtocol: f.GetProtocol(t),
OutputProtocol: f.GetProtocol(t),
SeqId: 0,
}
}
func NewAgentClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *AgentClient {
return &AgentClient{Transport: t,
ProtocolFactory: nil,
InputProtocol: iprot,
OutputProtocol: oprot,
SeqId: 0,
}
}
// Parameters:
// - Batch
func (p *AgentClient) EmitBatch(batch *Batch) (err error) {
if err = p.sendEmitBatch(batch); err != nil {
return
}
return
}
func (p *AgentClient) sendEmitBatch(batch *Batch) (err error) {
oprot := p.OutputProtocol
if oprot == nil {
oprot = p.ProtocolFactory.GetProtocol(p.Transport)
p.OutputProtocol = oprot
}
p.SeqId++
if err = oprot.WriteMessageBegin("emitBatch", thrift.ONEWAY, p.SeqId); err != nil {
return
}
args := AgentEmitBatchArgs{
Batch: batch,
}
if err = args.Write(oprot); err != nil {
return
}
if err = oprot.WriteMessageEnd(); err != nil {
return
}
return oprot.Flush(context.Background())
}
type AgentProcessor struct {
processorMap map[string]thrift.TProcessorFunction
handler Agent
}
func (p *AgentProcessor) AddToProcessorMap(key string, processor thrift.TProcessorFunction) {
p.processorMap[key] = processor
}
func (p *AgentProcessor) GetProcessorFunction(key string) (processor thrift.TProcessorFunction, ok bool) {
processor, ok = p.processorMap[key]
return processor, ok
}
func (p *AgentProcessor) ProcessorMap() map[string]thrift.TProcessorFunction {
return p.processorMap
}
func NewAgentProcessor(handler Agent) *AgentProcessor {
self0 := &AgentProcessor{handler: handler, processorMap: make(map[string]thrift.TProcessorFunction)}
self0.processorMap["emitBatch"] = &agentProcessorEmitBatch{handler: handler}
return self0
}
func (p *AgentProcessor) Process(iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
ctx := context.Background()
name, _, seqId, err := iprot.ReadMessageBegin()
if err != nil {
return false, err
}
if processor, ok := p.GetProcessorFunction(name); ok {
return processor.Process(ctx, seqId, iprot, oprot)
}
iprot.Skip(thrift.STRUCT)
iprot.ReadMessageEnd()
x1 := thrift.NewTApplicationException(thrift.UNKNOWN_METHOD, "Unknown function "+name)
oprot.WriteMessageBegin(name, thrift.EXCEPTION, seqId)
x1.Write(oprot)
oprot.WriteMessageEnd()
oprot.Flush(ctx)
return false, x1
}
type agentProcessorEmitBatch struct {
handler Agent
}
func (p *agentProcessorEmitBatch) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
args := AgentEmitBatchArgs{}
if err = args.Read(iprot); err != nil {
iprot.ReadMessageEnd()
return false, err
}
iprot.ReadMessageEnd()
var err2 error
if err2 = p.handler.EmitBatch(args.Batch); err2 != nil {
return true, err2
}
return true, nil
}
// HELPER FUNCTIONS AND STRUCTURES
// Attributes:
// - Batch
type AgentEmitBatchArgs struct {
Batch *Batch `thrift:"batch,1" json:"batch"`
}
func NewAgentEmitBatchArgs() *AgentEmitBatchArgs {
return &AgentEmitBatchArgs{}
}
var AgentEmitBatchArgs_Batch_DEFAULT *Batch
func (p *AgentEmitBatchArgs) GetBatch() *Batch {
if !p.IsSetBatch() {
return AgentEmitBatchArgs_Batch_DEFAULT
}
return p.Batch
}
func (p *AgentEmitBatchArgs) IsSetBatch() bool {
return p.Batch != nil
}
func (p *AgentEmitBatchArgs) Read(iprot thrift.TProtocol) error {
if _, err := iprot.ReadStructBegin(); err != nil {
return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err)
}
for {
_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
if err != nil {
return thrift.PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
}
if fieldTypeId == thrift.STOP {
break
}
switch fieldId {
case 1:
if err := p.readField1(iprot); err != nil {
return err
}
default:
if err := iprot.Skip(fieldTypeId); err != nil {
return err
}
}
if err := iprot.ReadFieldEnd(); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
}
return nil
}
func (p *AgentEmitBatchArgs) readField1(iprot thrift.TProtocol) error {
p.Batch = &Batch{}
if err := p.Batch.Read(iprot); err != nil {
return thrift.PrependError(fmt.Sprintf("%T error reading struct: ", p.Batch), err)
}
return nil
}
func (p *AgentEmitBatchArgs) Write(oprot thrift.TProtocol) error {
if err := oprot.WriteStructBegin("emitBatch_args"); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
}
if err := p.writeField1(oprot); err != nil {
return err
}
if err := oprot.WriteFieldStop(); err != nil {
return thrift.PrependError("write field stop error: ", err)
}
if err := oprot.WriteStructEnd(); err != nil {
return thrift.PrependError("write struct stop error: ", err)
}
return nil
}
func (p *AgentEmitBatchArgs) writeField1(oprot thrift.TProtocol) (err error) {
if err := oprot.WriteFieldBegin("batch", thrift.STRUCT, 1); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write field begin error 1:batch: ", p), err)
}
if err := p.Batch.Write(oprot); err != nil {
return thrift.PrependError(fmt.Sprintf("%T error writing struct: ", p.Batch), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return thrift.PrependError(fmt.Sprintf("%T write field end error 1:batch: ", p), err)
}
return err
}
func (p *AgentEmitBatchArgs) String() string {
if p == nil {
return "<nil>"
}
return fmt.Sprintf("AgentEmitBatchArgs(%+v)", *p)
}
@@ -1,5 +1,4 @@
// Autogenerated by Thrift Compiler (0.11.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package main
@@ -7,18 +6,18 @@ import (
"context"
"flag"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/jaeger"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/jaeger"
)
var _ = jaeger.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
flag.PrintDefaults()
@@ -28,6 +27,22 @@ func Usage() {
os.Exit(0)
}
type httpHeaders map[string]string
func (h httpHeaders) String() string {
var m map[string]string = h
return fmt.Sprintf("%s", m)
}
func (h httpHeaders) Set(value string) error {
parts := strings.Split(value, ": ")
if len(parts) != 2 {
return fmt.Errorf("header should be of format 'Key: Value'")
}
h[parts[0]] = parts[1]
return nil
}
func main() {
flag.Usage = Usage
var host string
@@ -35,8 +50,9 @@ func main() {
var protocol string
var urlString string
var framed bool
var useHTTP bool
var parsedURL *url.URL
var useHttp bool
headers := make(httpHeaders)
var parsedUrl *url.URL
var trans thrift.TTransport
_ = strconv.Atoi
_ = math.Abs
@@ -46,19 +62,20 @@ func main() {
flag.StringVar(&protocol, "P", "binary", "Specify the protocol (binary, compact, simplejson, json)")
flag.StringVar(&urlString, "u", "", "Specify the url")
flag.BoolVar(&framed, "framed", false, "Use framed transport")
flag.BoolVar(&useHTTP, "http", false, "Use http")
flag.BoolVar(&useHttp, "http", false, "Use http")
flag.Var(headers, "H", "Headers to set on the http(s) request (e.g. -H \"Key: Value\")")
flag.Parse()
if len(urlString) > 0 {
var err error
parsedURL, err = url.Parse(urlString)
parsedUrl, err = url.Parse(urlString)
if err != nil {
fmt.Fprintln(os.Stderr, "Error parsing URL: ", err)
flag.Usage()
}
host = parsedURL.Host
useHTTP = len(parsedURL.Scheme) <= 0 || parsedURL.Scheme == "http"
} else if useHTTP {
host = parsedUrl.Host
useHttp = len(parsedUrl.Scheme) <= 0 || parsedUrl.Scheme == "http" || parsedUrl.Scheme == "https"
} else if useHttp {
_, err := url.Parse(fmt.Sprint("http://", host, ":", port))
if err != nil {
fmt.Fprintln(os.Stderr, "Error parsing URL: ", err)
@@ -68,8 +85,14 @@ func main() {
cmd := flag.Arg(0)
var err error
if useHTTP {
trans, err = thrift.NewTHttpClient(parsedURL.String())
if useHttp {
trans, err = thrift.NewTHttpClient(parsedUrl.String())
if len(headers) > 0 {
httptrans := trans.(*thrift.THttpClient)
for key, value := range headers {
httptrans.SetHeader(key, value)
}
}
} else {
portStr := fmt.Sprint(port)
if strings.Contains(host, ":") {
@@ -126,19 +149,19 @@ func main() {
fmt.Fprintln(os.Stderr, "SubmitBatches requires 1 args")
flag.Usage()
}
arg12 := flag.Arg(1)
mbTrans13 := thrift.NewTMemoryBufferLen(len(arg12))
defer mbTrans13.Close()
_, err14 := mbTrans13.WriteString(arg12)
if err14 != nil {
arg19 := flag.Arg(1)
mbTrans20 := thrift.NewTMemoryBufferLen(len(arg19))
defer mbTrans20.Close()
_, err21 := mbTrans20.WriteString(arg19)
if err21 != nil {
Usage()
return
}
factory15 := thrift.NewTSimpleJSONProtocolFactory()
jsProt16 := factory15.GetProtocol(mbTrans13)
factory22 := thrift.NewTJSONProtocolFactory()
jsProt23 := factory22.GetProtocol(mbTrans20)
containerStruct0 := jaeger.NewCollectorSubmitBatchesArgs()
err17 := containerStruct0.ReadField1(jsProt16)
if err17 != nil {
err24 := containerStruct0.ReadField1(context.Background(), jsProt23)
if err24 != nil {
Usage()
return
}
@@ -1,5 +1,4 @@
// Autogenerated by Thrift Compiler (0.11.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package jaeger
@@ -7,16 +6,15 @@ import (
"bytes"
"context"
"fmt"
"reflect"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"time"
)
// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = context.Background
var _ = reflect.DeepEqual
var _ = time.Now
var _ = bytes.Equal
func init() {
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,6 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package zipkincore
var GoUnusedProtection__ int;
@@ -0,0 +1,179 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package main
import (
"context"
"flag"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/gen-go/zipkincore"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
)
var _ = zipkincore.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
flag.PrintDefaults()
fmt.Fprintln(os.Stderr, "\nFunctions:")
fmt.Fprintln(os.Stderr, " submitZipkinBatch( spans)")
fmt.Fprintln(os.Stderr)
os.Exit(0)
}
type httpHeaders map[string]string
func (h httpHeaders) String() string {
var m map[string]string = h
return fmt.Sprintf("%s", m)
}
func (h httpHeaders) Set(value string) error {
parts := strings.Split(value, ": ")
if len(parts) != 2 {
return fmt.Errorf("header should be of format 'Key: Value'")
}
h[parts[0]] = parts[1]
return nil
}
func main() {
flag.Usage = Usage
var host string
var port int
var protocol string
var urlString string
var framed bool
var useHttp bool
headers := make(httpHeaders)
var parsedUrl *url.URL
var trans thrift.TTransport
_ = strconv.Atoi
_ = math.Abs
flag.Usage = Usage
flag.StringVar(&host, "h", "localhost", "Specify host and port")
flag.IntVar(&port, "p", 9090, "Specify port")
flag.StringVar(&protocol, "P", "binary", "Specify the protocol (binary, compact, simplejson, json)")
flag.StringVar(&urlString, "u", "", "Specify the url")
flag.BoolVar(&framed, "framed", false, "Use framed transport")
flag.BoolVar(&useHttp, "http", false, "Use http")
flag.Var(headers, "H", "Headers to set on the http(s) request (e.g. -H \"Key: Value\")")
flag.Parse()
if len(urlString) > 0 {
var err error
parsedUrl, err = url.Parse(urlString)
if err != nil {
fmt.Fprintln(os.Stderr, "Error parsing URL: ", err)
flag.Usage()
}
host = parsedUrl.Host
useHttp = len(parsedUrl.Scheme) <= 0 || parsedUrl.Scheme == "http" || parsedUrl.Scheme == "https"
} else if useHttp {
_, err := url.Parse(fmt.Sprint("http://", host, ":", port))
if err != nil {
fmt.Fprintln(os.Stderr, "Error parsing URL: ", err)
flag.Usage()
}
}
cmd := flag.Arg(0)
var err error
if useHttp {
trans, err = thrift.NewTHttpClient(parsedUrl.String())
if len(headers) > 0 {
httptrans := trans.(*thrift.THttpClient)
for key, value := range headers {
httptrans.SetHeader(key, value)
}
}
} else {
portStr := fmt.Sprint(port)
if strings.Contains(host, ":") {
host, portStr, err = net.SplitHostPort(host)
if err != nil {
fmt.Fprintln(os.Stderr, "error with host:", err)
os.Exit(1)
}
}
trans, err = thrift.NewTSocket(net.JoinHostPort(host, portStr))
if err != nil {
fmt.Fprintln(os.Stderr, "error resolving address:", err)
os.Exit(1)
}
if framed {
trans = thrift.NewTFramedTransport(trans)
}
}
if err != nil {
fmt.Fprintln(os.Stderr, "Error creating transport", err)
os.Exit(1)
}
defer trans.Close()
var protocolFactory thrift.TProtocolFactory
switch protocol {
case "compact":
protocolFactory = thrift.NewTCompactProtocolFactory()
break
case "simplejson":
protocolFactory = thrift.NewTSimpleJSONProtocolFactory()
break
case "json":
protocolFactory = thrift.NewTJSONProtocolFactory()
break
case "binary", "":
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
break
default:
fmt.Fprintln(os.Stderr, "Invalid protocol specified: ", protocol)
Usage()
os.Exit(1)
}
iprot := protocolFactory.GetProtocol(trans)
oprot := protocolFactory.GetProtocol(trans)
client := zipkincore.NewZipkinCollectorClient(thrift.NewTStandardClient(iprot, oprot))
if err := trans.Open(); err != nil {
fmt.Fprintln(os.Stderr, "Error opening socket to ", host, ":", port, " ", err)
os.Exit(1)
}
switch cmd {
case "submitZipkinBatch":
if flag.NArg()-1 != 1 {
fmt.Fprintln(os.Stderr, "SubmitZipkinBatch requires 1 args")
flag.Usage()
}
arg11 := flag.Arg(1)
mbTrans12 := thrift.NewTMemoryBufferLen(len(arg11))
defer mbTrans12.Close()
_, err13 := mbTrans12.WriteString(arg11)
if err13 != nil {
Usage()
return
}
factory14 := thrift.NewTJSONProtocolFactory()
jsProt15 := factory14.GetProtocol(mbTrans12)
containerStruct0 := zipkincore.NewZipkinCollectorSubmitZipkinBatchArgs()
err16 := containerStruct0.ReadField1(context.Background(), jsProt15)
if err16 != nil {
Usage()
return
}
argvalue0 := containerStruct0.Spans
value0 := argvalue0
fmt.Print(client.SubmitZipkinBatch(context.Background(), value0))
fmt.Print("\n")
break
case "":
Usage()
break
default:
fmt.Fprintln(os.Stderr, "Invalid function ", cmd)
}
}
@@ -0,0 +1,38 @@
// Code generated by Thrift Compiler (0.14.1). DO NOT EDIT.
package zipkincore
import (
"bytes"
"context"
"fmt"
"go.opentelemetry.io/otel/exporters/trace/jaeger/internal/third_party/thrift/lib/go/thrift"
"time"
)
// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = context.Background
var _ = time.Now
var _ = bytes.Equal
const CLIENT_SEND = "cs"
const CLIENT_RECV = "cr"
const SERVER_SEND = "ss"
const SERVER_RECV = "sr"
const MESSAGE_SEND = "ms"
const MESSAGE_RECV = "mr"
const WIRE_SEND = "ws"
const WIRE_RECV = "wr"
const CLIENT_SEND_FRAGMENT = "csf"
const CLIENT_RECV_FRAGMENT = "crf"
const SERVER_SEND_FRAGMENT = "ssf"
const SERVER_RECV_FRAGMENT = "srf"
const LOCAL_COMPONENT = "lc"
const CLIENT_ADDR = "ca"
const SERVER_ADDR = "sa"
const MESSAGE_ADDR = "ma"
func init() {
}
File diff suppressed because it is too large Load Diff
@@ -236,4 +236,71 @@ For the lib/nodejs/lib/thrift/json_parse.js:
*/
(By Douglas Crockford <douglas@crockford.com>)
--------------------------------------------------
For lib/cpp/src/thrift/windows/SocketPair.cpp
/* socketpair.c
* Copyright 2007 by Nathan C. Myers <ncm@cantrip.org>; some rights reserved.
* This code is Free Software. It may be copied freely, in original or
* modified form, subject only to the restrictions that (1) the author is
* relieved from all responsibilities for any use for any purpose, and (2)
* this copyright notice must be retained, unchanged, in its entirety. If
* for any reason the author might be held responsible for any consequences
* of copying or use, license is withheld.
*/
--------------------------------------------------
For lib/py/compat/win32/stdint.h
// ISO C9x compliant stdint.h for Microsoft Visual Studio
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
//
// Copyright (c) 2006-2008 Alexander Chemeris
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. The name of the author may be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
///////////////////////////////////////////////////////////////////////////////
--------------------------------------------------
Codegen template in t_html_generator.h
* Bootstrap v2.0.3
*
* Copyright 2012 Twitter, Inc
* Licensed under the Apache License v2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
* Designed and built with all the love in the world @twitter by @mdo and @fat.
---------------------------------------------------
For t_cl_generator.cc
* Copyright (c) 2008- Patrick Collison <patrick@collison.ie>
* Copyright (c) 2006- Facebook
---------------------------------------------------
@@ -19,6 +19,10 @@
package thrift
import (
"context"
)
const (
UNKNOWN_APPLICATION_EXCEPTION = 0
UNKNOWN_METHOD = 1
@@ -51,8 +55,8 @@ var defaultApplicationExceptionMessage = map[int32]string{
type TApplicationException interface {
TException
TypeId() int32
Read(iprot TProtocol) error
Write(oprot TProtocol) error
Read(ctx context.Context, iprot TProtocol) error
Write(ctx context.Context, oprot TProtocol) error
}
type tApplicationException struct {
@@ -60,6 +64,12 @@ type tApplicationException struct {
type_ int32
}
var _ TApplicationException = (*tApplicationException)(nil)
func (tApplicationException) TExceptionType() TExceptionType {
return TExceptionTypeApplication
}
func (e tApplicationException) Error() string {
if e.message != "" {
return e.message
@@ -75,9 +85,9 @@ func (p *tApplicationException) TypeId() int32 {
return p.type_
}
func (p *tApplicationException) Read(iprot TProtocol) error {
func (p *tApplicationException) Read(ctx context.Context, iprot TProtocol) error {
// TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
_, err := iprot.ReadStructBegin(ctx)
if err != nil {
return err
}
@@ -86,7 +96,7 @@ func (p *tApplicationException) Read(iprot TProtocol) error {
type_ := int32(UNKNOWN_APPLICATION_EXCEPTION)
for {
_, ttype, id, err := iprot.ReadFieldBegin()
_, ttype, id, err := iprot.ReadFieldBegin(ctx)
if err != nil {
return err
}
@@ -96,34 +106,34 @@ func (p *tApplicationException) Read(iprot TProtocol) error {
switch id {
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
if message, err = iprot.ReadString(ctx); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil {
return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
if type_, err = iprot.ReadI32(ctx); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil {
return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
if err = SkipDefaultDepth(ctx, iprot, ttype); err != nil {
return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
if err = iprot.ReadFieldEnd(ctx); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
if err := iprot.ReadStructEnd(ctx); err != nil {
return err
}
@@ -133,38 +143,38 @@ func (p *tApplicationException) Read(iprot TProtocol) error {
return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {
err = oprot.WriteStructBegin("TApplicationException")
func (p *tApplicationException) Write(ctx context.Context, oprot TProtocol) (err error) {
err = oprot.WriteStructBegin(ctx, "TApplicationException")
if len(p.Error()) > 0 {
err = oprot.WriteFieldBegin("message", STRING, 1)
err = oprot.WriteFieldBegin(ctx, "message", STRING, 1)
if err != nil {
return
}
err = oprot.WriteString(p.Error())
err = oprot.WriteString(ctx, p.Error())
if err != nil {
return
}
err = oprot.WriteFieldEnd()
err = oprot.WriteFieldEnd(ctx)
if err != nil {
return
}
}
err = oprot.WriteFieldBegin("type", I32, 2)
err = oprot.WriteFieldBegin(ctx, "type", I32, 2)
if err != nil {
return
}
err = oprot.WriteI32(p.type_)
err = oprot.WriteI32(ctx, p.type_)
if err != nil {
return
}
err = oprot.WriteFieldEnd()
err = oprot.WriteFieldEnd(ctx)
if err != nil {
return
}
err = oprot.WriteFieldStop()
err = oprot.WriteFieldStop(ctx)
if err != nil {
return
}
err = oprot.WriteStructEnd()
err = oprot.WriteStructEnd(ctx)
return
}
@@ -32,22 +32,37 @@ import (
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
strictRead bool
strictWrite bool
cfg *TConfiguration
buffer [64]byte
}
type TBinaryProtocolFactory struct {
strictRead bool
strictWrite bool
cfg *TConfiguration
}
// Deprecated: Use NewTBinaryProtocolConf instead.
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
return NewTBinaryProtocol(t, false, true)
return NewTBinaryProtocolConf(t, &TConfiguration{
noPropagation: true,
})
}
// Deprecated: Use NewTBinaryProtocolConf instead.
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
return NewTBinaryProtocolConf(t, &TConfiguration{
TBinaryStrictRead: &strictRead,
TBinaryStrictWrite: &strictWrite,
noPropagation: true,
})
}
func NewTBinaryProtocolConf(t TTransport, conf *TConfiguration) *TBinaryProtocol {
PropagateTConfiguration(t, conf)
p := &TBinaryProtocol{
origTransport: t,
cfg: conf,
}
if et, ok := t.(TRichTransport); ok {
p.trans = et
} else {
@@ -56,162 +71,181 @@ func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProt
return p
}
// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
return NewTBinaryProtocolFactory(false, true)
return NewTBinaryProtocolFactoryConf(&TConfiguration{
noPropagation: true,
})
}
// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
return NewTBinaryProtocolFactoryConf(&TConfiguration{
TBinaryStrictRead: &strictRead,
TBinaryStrictWrite: &strictWrite,
noPropagation: true,
})
}
func NewTBinaryProtocolFactoryConf(conf *TConfiguration) *TBinaryProtocolFactory {
return &TBinaryProtocolFactory{
cfg: conf,
}
}
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
return NewTBinaryProtocolConf(t, p.cfg)
}
func (p *TBinaryProtocolFactory) SetTConfiguration(conf *TConfiguration) {
p.cfg = conf
}
/**
* Writing Methods
*/
func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
if p.strictWrite {
func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
if p.cfg.GetTBinaryStrictWrite() {
version := uint32(VERSION_1) | uint32(typeId)
e := p.WriteI32(int32(version))
e := p.WriteI32(ctx, int32(version))
if e != nil {
return e
}
e = p.WriteString(name)
e = p.WriteString(ctx, name)
if e != nil {
return e
}
e = p.WriteI32(seqId)
e = p.WriteI32(ctx, seqId)
return e
} else {
e := p.WriteString(name)
e := p.WriteString(ctx, name)
if e != nil {
return e
}
e = p.WriteByte(int8(typeId))
e = p.WriteByte(ctx, int8(typeId))
if e != nil {
return e
}
e = p.WriteI32(seqId)
e = p.WriteI32(ctx, seqId)
return e
}
return nil
}
func (p *TBinaryProtocol) WriteMessageEnd() error {
func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) WriteStructBegin(name string) error {
func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error {
return nil
}
func (p *TBinaryProtocol) WriteStructEnd() error {
func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
e := p.WriteByte(int8(typeId))
func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
e := p.WriteByte(ctx, int8(typeId))
if e != nil {
return e
}
e = p.WriteI16(id)
e = p.WriteI16(ctx, id)
return e
}
func (p *TBinaryProtocol) WriteFieldEnd() error {
func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) WriteFieldStop() error {
e := p.WriteByte(STOP)
func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error {
e := p.WriteByte(ctx, STOP)
return e
}
func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
e := p.WriteByte(int8(keyType))
func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
e := p.WriteByte(ctx, int8(keyType))
if e != nil {
return e
}
e = p.WriteByte(int8(valueType))
e = p.WriteByte(ctx, int8(valueType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
e = p.WriteI32(ctx, int32(size))
return e
}
func (p *TBinaryProtocol) WriteMapEnd() error {
func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
e := p.WriteByte(ctx, int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
e = p.WriteI32(ctx, int32(size))
return e
}
func (p *TBinaryProtocol) WriteListEnd() error {
func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
e := p.WriteByte(ctx, int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
e = p.WriteI32(ctx, int32(size))
return e
}
func (p *TBinaryProtocol) WriteSetEnd() error {
func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) WriteBool(value bool) error {
func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error {
if value {
return p.WriteByte(1)
return p.WriteByte(ctx, 1)
}
return p.WriteByte(0)
return p.WriteByte(ctx, 0)
}
func (p *TBinaryProtocol) WriteByte(value int8) error {
func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error {
e := p.trans.WriteByte(byte(value))
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI16(value int16) error {
func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error {
v := p.buffer[0:2]
binary.BigEndian.PutUint16(v, uint16(value))
_, e := p.trans.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI32(value int32) error {
func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error {
v := p.buffer[0:4]
binary.BigEndian.PutUint32(v, uint32(value))
_, e := p.trans.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI64(value int64) error {
func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error {
v := p.buffer[0:8]
binary.BigEndian.PutUint64(v, uint64(value))
_, err := p.trans.Write(v)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteDouble(value float64) error {
return p.WriteI64(int64(math.Float64bits(value)))
func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error {
return p.WriteI64(ctx, int64(math.Float64bits(value)))
}
func (p *TBinaryProtocol) WriteString(value string) error {
e := p.WriteI32(int32(len(value)))
func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error {
e := p.WriteI32(ctx, int32(len(value)))
if e != nil {
return e
}
@@ -219,8 +253,8 @@ func (p *TBinaryProtocol) WriteString(value string) error {
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteBinary(value []byte) error {
e := p.WriteI32(int32(len(value)))
func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error {
e := p.WriteI32(ctx, int32(len(value)))
if e != nil {
return e
}
@@ -232,8 +266,8 @@ func (p *TBinaryProtocol) WriteBinary(value []byte) error {
* Reading methods
*/
func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
size, e := p.ReadI32()
func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
size, e := p.ReadI32(ctx)
if e != nil {
return "", typeId, 0, NewTProtocolException(e)
}
@@ -243,79 +277,79 @@ func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType,
if version != VERSION_1 {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
}
name, e = p.ReadString()
name, e = p.ReadString(ctx)
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
seqId, e = p.ReadI32()
seqId, e = p.ReadI32(ctx)
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
return name, typeId, seqId, nil
}
if p.strictRead {
if p.cfg.GetTBinaryStrictRead() {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
}
name, e2 := p.readStringBody(size)
if e2 != nil {
return name, typeId, seqId, e2
}
b, e3 := p.ReadByte()
b, e3 := p.ReadByte(ctx)
if e3 != nil {
return name, typeId, seqId, e3
}
typeId = TMessageType(b)
seqId, e4 := p.ReadI32()
seqId, e4 := p.ReadI32(ctx)
if e4 != nil {
return name, typeId, seqId, e4
}
return name, typeId, seqId, nil
}
func (p *TBinaryProtocol) ReadMessageEnd() error {
func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
return
}
func (p *TBinaryProtocol) ReadStructEnd() error {
func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
t, err := p.ReadByte()
func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) {
t, err := p.ReadByte(ctx)
typeId = TType(t)
if err != nil {
return name, typeId, seqId, err
}
if t != STOP {
seqId, err = p.ReadI16()
seqId, err = p.ReadI16(ctx)
}
return name, typeId, seqId, err
}
func (p *TBinaryProtocol) ReadFieldEnd() error {
func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error {
return nil
}
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
k, e := p.ReadByte()
func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) {
k, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
kType = TType(k)
v, e := p.ReadByte()
v, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
vType = TType(v)
size32, e := p.ReadI32()
size32, e := p.ReadI32(ctx)
if e != nil {
err = NewTProtocolException(e)
return
@@ -328,18 +362,18 @@ func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err erro
return kType, vType, size, nil
}
func (p *TBinaryProtocol) ReadMapEnd() error {
func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
b, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
size32, e := p.ReadI32(ctx)
if e != nil {
err = NewTProtocolException(e)
return
@@ -353,18 +387,18 @@ func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error)
return
}
func (p *TBinaryProtocol) ReadListEnd() error {
func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
b, e := p.ReadByte(ctx)
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
size32, e := p.ReadI32(ctx)
if e != nil {
err = NewTProtocolException(e)
return
@@ -377,12 +411,12 @@ func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return elemType, size, nil
}
func (p *TBinaryProtocol) ReadSetEnd() error {
func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error {
return nil
}
func (p *TBinaryProtocol) ReadBool() (bool, error) {
b, e := p.ReadByte()
func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) {
b, e := p.ReadByte(ctx)
v := true
if b != 1 {
v = false
@@ -390,64 +424,75 @@ func (p *TBinaryProtocol) ReadBool() (bool, error) {
return v, e
}
func (p *TBinaryProtocol) ReadByte() (int8, error) {
func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) {
v, err := p.trans.ReadByte()
return int8(v), err
}
func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) {
buf := p.buffer[0:2]
err = p.readAll(buf)
err = p.readAll(ctx, buf)
value = int16(binary.BigEndian.Uint16(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) {
buf := p.buffer[0:4]
err = p.readAll(buf)
err = p.readAll(ctx, buf)
value = int32(binary.BigEndian.Uint32(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
err = p.readAll(ctx, buf)
value = int64(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
err = p.readAll(ctx, buf)
value = math.Float64frombits(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadString() (value string, err error) {
size, e := p.ReadI32()
func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) {
size, e := p.ReadI32(ctx)
if e != nil {
return "", e
}
err = checkSizeForProtocol(size, p.cfg)
if err != nil {
return
}
if size < 0 {
err = invalidDataLength
return
}
if size == 0 {
return "", nil
}
if size < int32(len(p.buffer)) {
// Avoid allocation on small reads
buf := p.buffer[:size]
read, e := io.ReadFull(p.trans, buf)
return string(buf[:read]), NewTProtocolException(e)
}
return p.readStringBody(size)
}
func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
size, e := p.ReadI32()
func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
size, e := p.ReadI32(ctx)
if e != nil {
return nil, e
}
if size < 0 {
return nil, invalidDataLength
if err := checkSizeForProtocol(size, p.cfg); err != nil {
return nil, err
}
isize := int(size)
buf := make([]byte, isize)
_, err := io.ReadFull(p.trans, buf)
buf, err := safeReadBytes(size, p.trans)
return buf, NewTProtocolException(err)
}
@@ -455,51 +500,56 @@ func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TBinaryProtocol) Transport() TTransport {
return p.origTransport
}
func (p *TBinaryProtocol) readAll(buf []byte) error {
_, err := io.ReadFull(p.trans, buf)
func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
var read int
_, deadlineSet := ctx.Deadline()
for {
read, err = io.ReadFull(p.trans, buf)
if deadlineSet && read == 0 && isTimeoutError(err) && ctx.Err() == nil {
// This is I/O timeout without anything read,
// and we still have time left, keep retrying.
continue
}
// For anything else, don't retry
break
}
return NewTProtocolException(err)
}
const readLimit = 32768
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
var (
buf bytes.Buffer
e error
b []byte
)
switch {
case int(size) <= len(p.buffer):
b = p.buffer[:size] // avoids allocation for small reads
case int(size) < readLimit:
b = make([]byte, size)
default:
b = make([]byte, readLimit)
}
for size > 0 {
_, e = io.ReadFull(p.trans, b)
buf.Write(b)
if e != nil {
break
}
size -= readLimit
if size < readLimit && size > 0 {
b = b[:size]
}
}
return buf.String(), NewTProtocolException(e)
buf, err := safeReadBytes(size, p.trans)
return string(buf), NewTProtocolException(err)
}
func (p *TBinaryProtocol) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(p.trans, conf)
PropagateTConfiguration(p.origTransport, conf)
p.cfg = conf
}
var (
_ TConfigurationSetter = (*TBinaryProtocolFactory)(nil)
_ TConfigurationSetter = (*TBinaryProtocol)(nil)
)
// This function is shared between TBinaryProtocol and TCompactProtocol.
//
// It tries to read size bytes from trans, in a way that prevents large
// allocations when size is insanely large (mostly caused by malformed message).
func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
if size < 0 {
return nil, nil
}
buf := new(bytes.Buffer)
_, err := io.CopyN(buf, trans, int64(size))
return buf.Bytes(), err
}
@@ -90,3 +90,10 @@ func (p *TBufferedTransport) Flush(ctx context.Context) error {
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {
return p.tp.RemainingBytes()
}
// SetTConfiguration implements TConfigurationSetter for propagation.
func (p *TBufferedTransport) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(p.tp, conf)
}
var _ TConfigurationSetter = (*TBufferedTransport)(nil)
@@ -5,8 +5,15 @@ import (
"fmt"
)
// ResponseMeta represents the metadata attached to the response.
type ResponseMeta struct {
// The headers in the response, if any.
// If the underlying transport/protocol is not THeader, this will always be nil.
Headers THeaderMap
}
type TClient interface {
Call(ctx context.Context, method string, args, result TStruct) error
Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error)
}
type TStandardClient struct {
@@ -34,20 +41,20 @@ func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32
}
}
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
if err := oprot.WriteMessageBegin(ctx, method, CALL, seqId); err != nil {
return err
}
if err := args.Write(oprot); err != nil {
if err := args.Write(ctx, oprot); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
if err := oprot.WriteMessageEnd(ctx); err != nil {
return err
}
return oprot.Flush(ctx)
}
func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin(ctx)
if err != nil {
return err
}
@@ -58,11 +65,11 @@ func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, resu
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
if err := exception.Read(iprot); err != nil {
if err := exception.Read(ctx, iprot); err != nil {
return err
}
if err := iprot.ReadMessageEnd(); err != nil {
if err := iprot.ReadMessageEnd(ctx); err != nil {
return err
}
@@ -71,25 +78,32 @@ func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, resu
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
if err := result.Read(iprot); err != nil {
if err := result.Read(ctx, iprot); err != nil {
return err
}
return iprot.ReadMessageEnd()
return iprot.ReadMessageEnd(ctx)
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
p.seqId++
seqId := p.seqId
if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
return err
return ResponseMeta{}, err
}
// method is oneway
if result == nil {
return nil
return ResponseMeta{}, nil
}
return p.Recv(p.iprot, seqId, method, result)
err := p.Recv(ctx, p.iprot, seqId, method, result)
var headers THeaderMap
if hp, ok := p.iprot.(*THeaderProtocol); ok {
headers = hp.transport.readHeaders
}
return ResponseMeta{
Headers: headers,
}, err
}
@@ -22,6 +22,7 @@ package thrift
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
@@ -74,20 +75,37 @@ func init() {
}
}
type TCompactProtocolFactory struct{}
type TCompactProtocolFactory struct {
cfg *TConfiguration
}
// Deprecated: Use NewTCompactProtocolFactoryConf instead.
func NewTCompactProtocolFactory() *TCompactProtocolFactory {
return &TCompactProtocolFactory{}
return NewTCompactProtocolFactoryConf(&TConfiguration{
noPropagation: true,
})
}
func NewTCompactProtocolFactoryConf(conf *TConfiguration) *TCompactProtocolFactory {
return &TCompactProtocolFactory{
cfg: conf,
}
}
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTCompactProtocol(trans)
return NewTCompactProtocolConf(trans, p.cfg)
}
func (p *TCompactProtocolFactory) SetTConfiguration(conf *TConfiguration) {
p.cfg = conf
}
type TCompactProtocol struct {
trans TRichTransport
origTransport TTransport
cfg *TConfiguration
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField []int
@@ -106,9 +124,19 @@ type TCompactProtocol struct {
buffer [64]byte
}
// Create a TCompactProtocol given a TTransport
// Deprecated: Use NewTCompactProtocolConf instead.
func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
p := &TCompactProtocol{origTransport: trans, lastField: []int{}}
return NewTCompactProtocolConf(trans, &TConfiguration{
noPropagation: true,
})
}
func NewTCompactProtocolConf(trans TTransport, conf *TConfiguration) *TCompactProtocol {
PropagateTConfiguration(trans, conf)
p := &TCompactProtocol{
origTransport: trans,
cfg: conf,
}
if et, ok := trans.(TRichTransport); ok {
p.trans = et
} else {
@@ -116,7 +144,6 @@ func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
}
return p
}
//
@@ -125,7 +152,7 @@ func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
// Write a message header to the wire. Compact Protocol messages contain the
// protocol version so we can migrate forwards in the future if need be.
func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
func (p *TCompactProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {
err := p.writeByteDirect(COMPACT_PROTOCOL_ID)
if err != nil {
return NewTProtocolException(err)
@@ -138,17 +165,17 @@ func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, s
if err != nil {
return NewTProtocolException(err)
}
e := p.WriteString(name)
e := p.WriteString(ctx, name)
return e
}
func (p *TCompactProtocol) WriteMessageEnd() error { return nil }
func (p *TCompactProtocol) WriteMessageEnd(ctx context.Context) error { return nil }
// Write a struct begin. This doesn't actually put anything on the wire. We
// use it as an opportunity to put special placeholder markers on the field
// stack so we can get the field id deltas correct.
func (p *TCompactProtocol) WriteStructBegin(name string) error {
func (p *TCompactProtocol) WriteStructBegin(ctx context.Context, name string) error {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return nil
@@ -157,26 +184,29 @@ func (p *TCompactProtocol) WriteStructBegin(name string) error {
// Write a struct end. This doesn't actually put anything on the wire. We use
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
func (p *TCompactProtocol) WriteStructEnd() error {
func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {
if len(p.lastField) <= 0 {
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("WriteStructEnd called without matching WriteStructBegin call before"))
}
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
func (p *TCompactProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
if typeId == BOOL {
// we want to possibly include the value, so we'll wait.
p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true
return nil
}
_, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF)
_, err := p.writeFieldBeginInternal(ctx, name, typeId, id, 0xFF)
return NewTProtocolException(err)
}
// The workhorse of writeFieldBegin. It has the option of doing a
// 'type override' of the type header. This is used specifically in the
// boolean field case.
func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) {
func (p *TCompactProtocol) writeFieldBeginInternal(ctx context.Context, name string, typeId TType, id int16, typeOverride byte) (int, error) {
// short lastField = lastField_.pop();
// if there's a type override, use that.
@@ -201,7 +231,7 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id
if err != nil {
return 0, err
}
err = p.WriteI16(id)
err = p.WriteI16(ctx, id)
written = 1 + 2
if err != nil {
return 0, err
@@ -209,18 +239,17 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id
}
p.lastFieldId = fieldId
// p.lastField.Push(field.id);
return written, nil
}
func (p *TCompactProtocol) WriteFieldEnd() error { return nil }
func (p *TCompactProtocol) WriteFieldEnd(ctx context.Context) error { return nil }
func (p *TCompactProtocol) WriteFieldStop() error {
func (p *TCompactProtocol) WriteFieldStop(ctx context.Context) error {
err := p.writeByteDirect(STOP)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
func (p *TCompactProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
if size == 0 {
err := p.writeByteDirect(0)
return NewTProtocolException(err)
@@ -233,32 +262,32 @@ func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size in
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapEnd() error { return nil }
func (p *TCompactProtocol) WriteMapEnd(ctx context.Context) error { return nil }
// Write a list header.
func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error {
func (p *TCompactProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteListEnd() error { return nil }
func (p *TCompactProtocol) WriteListEnd(ctx context.Context) error { return nil }
// Write a set header.
func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error {
func (p *TCompactProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteSetEnd() error { return nil }
func (p *TCompactProtocol) WriteSetEnd(ctx context.Context) error { return nil }
func (p *TCompactProtocol) WriteBool(value bool) error {
func (p *TCompactProtocol) WriteBool(ctx context.Context, value bool) error {
v := byte(COMPACT_BOOLEAN_FALSE)
if value {
v = byte(COMPACT_BOOLEAN_TRUE)
}
if p.booleanFieldPending {
// we haven't written the field header yet
_, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v)
_, err := p.writeFieldBeginInternal(ctx, p.booleanFieldName, BOOL, p.booleanFieldId, v)
p.booleanFieldPending = false
return NewTProtocolException(err)
}
@@ -268,31 +297,31 @@ func (p *TCompactProtocol) WriteBool(value bool) error {
}
// Write a byte. Nothing to see here!
func (p *TCompactProtocol) WriteByte(value int8) error {
func (p *TCompactProtocol) WriteByte(ctx context.Context, value int8) error {
err := p.writeByteDirect(byte(value))
return NewTProtocolException(err)
}
// Write an I16 as a zigzag varint.
func (p *TCompactProtocol) WriteI16(value int16) error {
func (p *TCompactProtocol) WriteI16(ctx context.Context, value int16) error {
_, err := p.writeVarint32(p.int32ToZigzag(int32(value)))
return NewTProtocolException(err)
}
// Write an i32 as a zigzag varint.
func (p *TCompactProtocol) WriteI32(value int32) error {
func (p *TCompactProtocol) WriteI32(ctx context.Context, value int32) error {
_, err := p.writeVarint32(p.int32ToZigzag(value))
return NewTProtocolException(err)
}
// Write an i64 as a zigzag varint.
func (p *TCompactProtocol) WriteI64(value int64) error {
func (p *TCompactProtocol) WriteI64(ctx context.Context, value int64) error {
_, err := p.writeVarint64(p.int64ToZigzag(value))
return NewTProtocolException(err)
}
// Write a double to the wire as 8 bytes.
func (p *TCompactProtocol) WriteDouble(value float64) error {
func (p *TCompactProtocol) WriteDouble(ctx context.Context, value float64) error {
buf := p.buffer[0:8]
binary.LittleEndian.PutUint64(buf, math.Float64bits(value))
_, err := p.trans.Write(buf)
@@ -300,7 +329,7 @@ func (p *TCompactProtocol) WriteDouble(value float64) error {
}
// Write a string to the wire with a varint size preceding.
func (p *TCompactProtocol) WriteString(value string) error {
func (p *TCompactProtocol) WriteString(ctx context.Context, value string) error {
_, e := p.writeVarint32(int32(len(value)))
if e != nil {
return NewTProtocolException(e)
@@ -312,7 +341,7 @@ func (p *TCompactProtocol) WriteString(value string) error {
}
// Write a byte array, using a varint for the size.
func (p *TCompactProtocol) WriteBinary(bin []byte) error {
func (p *TCompactProtocol) WriteBinary(ctx context.Context, bin []byte) error {
_, e := p.writeVarint32(int32(len(bin)))
if e != nil {
return NewTProtocolException(e)
@@ -329,9 +358,20 @@ func (p *TCompactProtocol) WriteBinary(bin []byte) error {
//
// Read a message header.
func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
func (p *TCompactProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
var protocolId byte
protocolId, err := p.readByteDirect()
_, deadlineSet := ctx.Deadline()
for {
protocolId, err = p.readByteDirect()
if deadlineSet && isTimeoutError(err) && ctx.Err() == nil {
// keep retrying I/O timeout errors since we still have
// time left
continue
}
// For anything else, don't retry
break
}
if err != nil {
return
}
@@ -358,15 +398,15 @@ func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType,
err = NewTProtocolException(e)
return
}
name, err = p.ReadString()
name, err = p.ReadString(ctx)
return
}
func (p *TCompactProtocol) ReadMessageEnd() error { return nil }
func (p *TCompactProtocol) ReadMessageEnd(ctx context.Context) error { return nil }
// Read a struct begin. There's nothing on the wire for this, but it is our
// opportunity to push a new struct begin marker onto the field stack.
func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
func (p *TCompactProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return
@@ -374,15 +414,18 @@ func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
// Doesn't actually consume any wire data, just removes the last field for
// this struct from the field stack.
func (p *TCompactProtocol) ReadStructEnd() error {
func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error {
// consume the last field we read off the wire.
if len(p.lastField) <= 0 {
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("ReadStructEnd called without matching ReadStructBegin call before"))
}
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
// Read a field header off the wire.
func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
func (p *TCompactProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) {
t, err := p.readByteDirect()
if err != nil {
return
@@ -397,7 +440,7 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16
modifier := int16((t & 0xf0) >> 4)
if modifier == 0 {
// not a delta. look ahead for the zigzag varint field id.
id, err = p.ReadI16()
id, err = p.ReadI16(ctx)
if err != nil {
return
}
@@ -423,12 +466,12 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16
return
}
func (p *TCompactProtocol) ReadFieldEnd() error { return nil }
func (p *TCompactProtocol) ReadFieldEnd(ctx context.Context) error { return nil }
// Read a map header off the wire. If the size is zero, skip reading the key
// and value type. This means that 0-length maps will yield TMaps without the
// "correct" types.
func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
func (p *TCompactProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
size32, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
@@ -452,13 +495,13 @@ func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size
return
}
func (p *TCompactProtocol) ReadMapEnd() error { return nil }
func (p *TCompactProtocol) ReadMapEnd(ctx context.Context) error { return nil }
// Read a list header off the wire. If the list size is 0-14, the size will
// be packed into the element type header. If it's a longer list, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) {
func (p *TCompactProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
size_and_type, err := p.readByteDirect()
if err != nil {
return
@@ -484,22 +527,22 @@ func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error)
return
}
func (p *TCompactProtocol) ReadListEnd() error { return nil }
func (p *TCompactProtocol) ReadListEnd(ctx context.Context) error { return nil }
// Read a set header off the wire. If the set size is 0-14, the size will
// be packed into the element type header. If it's a longer set, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return p.ReadListBegin()
func (p *TCompactProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
return p.ReadListBegin(ctx)
}
func (p *TCompactProtocol) ReadSetEnd() error { return nil }
func (p *TCompactProtocol) ReadSetEnd(ctx context.Context) error { return nil }
// Read a boolean off the wire. If this is a boolean field, the value should
// already have been read during readFieldBegin, so we'll just consume the
// pre-stored value. Otherwise, read a byte.
func (p *TCompactProtocol) ReadBool() (value bool, err error) {
func (p *TCompactProtocol) ReadBool(ctx context.Context) (value bool, err error) {
if p.boolValueIsNotNull {
p.boolValueIsNotNull = false
return p.boolValue, nil
@@ -509,7 +552,7 @@ func (p *TCompactProtocol) ReadBool() (value bool, err error) {
}
// Read a single byte off the wire. Nothing interesting here.
func (p *TCompactProtocol) ReadByte() (int8, error) {
func (p *TCompactProtocol) ReadByte(ctx context.Context) (int8, error) {
v, err := p.readByteDirect()
if err != nil {
return 0, NewTProtocolException(err)
@@ -518,13 +561,13 @@ func (p *TCompactProtocol) ReadByte() (int8, error) {
}
// Read an i16 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI16() (value int16, err error) {
v, err := p.ReadI32()
func (p *TCompactProtocol) ReadI16(ctx context.Context) (value int16, err error) {
v, err := p.ReadI32(ctx)
return int16(v), err
}
// Read an i32 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI32() (value int32, err error) {
func (p *TCompactProtocol) ReadI32(ctx context.Context) (value int32, err error) {
v, e := p.readVarint32()
if e != nil {
return 0, NewTProtocolException(e)
@@ -534,7 +577,7 @@ func (p *TCompactProtocol) ReadI32() (value int32, err error) {
}
// Read an i64 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI64() (value int64, err error) {
func (p *TCompactProtocol) ReadI64(ctx context.Context) (value int64, err error) {
v, e := p.readVarint64()
if e != nil {
return 0, NewTProtocolException(e)
@@ -544,7 +587,7 @@ func (p *TCompactProtocol) ReadI64() (value int64, err error) {
}
// No magic here - just read a double off the wire.
func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
func (p *TCompactProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
longBits := p.buffer[0:8]
_, e := io.ReadFull(p.trans, longBits)
if e != nil {
@@ -554,43 +597,44 @@ func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
}
// Reads a []byte (via readBinary), and then UTF-8 decodes it.
func (p *TCompactProtocol) ReadString() (value string, err error) {
func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err error) {
length, e := p.readVarint32()
if e != nil {
return "", NewTProtocolException(e)
}
if length < 0 {
return "", invalidDataLength
err = checkSizeForProtocol(length, p.cfg)
if err != nil {
return
}
if length == 0 {
return "", nil
}
var buf []byte
if length <= int32(len(p.buffer)) {
buf = p.buffer[0:length]
} else {
buf = make([]byte, length)
if length < int32(len(p.buffer)) {
// Avoid allocation on small reads
buf := p.buffer[:length]
read, e := io.ReadFull(p.trans, buf)
return string(buf[:read]), NewTProtocolException(e)
}
_, e = io.ReadFull(p.trans, buf)
buf, e := safeReadBytes(length, p.trans)
return string(buf), NewTProtocolException(e)
}
// Read a []byte from the wire.
func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
length, e := p.readVarint32()
if e != nil {
return nil, NewTProtocolException(e)
}
err = checkSizeForProtocol(length, p.cfg)
if err != nil {
return
}
if length == 0 {
return []byte{}, nil
}
if length < 0 {
return nil, invalidDataLength
}
buf := make([]byte, length)
_, e = io.ReadFull(p.trans, buf)
buf, e := safeReadBytes(length, p.trans)
return buf, NewTProtocolException(e)
}
@@ -598,8 +642,8 @@ func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
func (p *TCompactProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TCompactProtocol) Transport() TTransport {
@@ -801,10 +845,21 @@ func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
case COMPACT_STRUCT:
return STRUCT, nil
}
return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
return STOP, NewTProtocolException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.
func (p *TCompactProtocol) getCompactType(t TType) tCompactType {
return ttypeToCompactType[t]
}
func (p *TCompactProtocol) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(p.trans, conf)
PropagateTConfiguration(p.origTransport, conf)
p.cfg = conf
}
var (
_ TConfigurationSetter = (*TCompactProtocolFactory)(nil)
_ TConfigurationSetter = (*TCompactProtocol)(nil)
)
@@ -0,0 +1,378 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"crypto/tls"
"fmt"
"time"
)
// Default TConfiguration values.
const (
DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024
DEFAULT_MAX_FRAME_SIZE = 16384000
DEFAULT_TBINARY_STRICT_READ = false
DEFAULT_TBINARY_STRICT_WRITE = true
DEFAULT_CONNECT_TIMEOUT = 0
DEFAULT_SOCKET_TIMEOUT = 0
)
// TConfiguration defines some configurations shared between TTransport,
// TProtocol, TTransportFactory, TProtocolFactory, and other implementations.
//
// When constructing TConfiguration, you only need to specify the non-default
// fields. All zero values have sane default values.
//
// Not all configurations defined are applicable to all implementations.
// Implementations are free to ignore the configurations not applicable to them.
//
// All functions attached to this type are nil-safe.
//
// See [1] for spec.
//
// NOTE: When using TConfiguration, fill in all the configurations you want to
// set across the stack, not only the ones you want to set in the immediate
// TTransport/TProtocol.
//
// For example, say you want to migrate this old code into using TConfiguration:
//
// sccket := thrift.NewTSocketTimeout("host:port", time.Second)
// transFactory := thrift.NewTFramedTransportFactoryMaxLength(
// thrift.NewTTransportFactory(),
// 1024 * 1024 * 256,
// )
// protoFactory := thrift.NewTBinaryProtocolFactory(true, true)
//
// This is the wrong way to do it because in the end the TConfiguration used by
// socket and transFactory will be overwritten by the one used by protoFactory
// because of TConfiguration propagation:
//
// // bad example, DO NOT USE
// sccket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{
// ConnectTimeout: time.Second,
// SocketTimeout: time.Second,
// })
// transFactory := thrift.NewTFramedTransportFactoryConf(
// thrift.NewTTransportFactory(),
// &thrift.TConfiguration{
// MaxFrameSize: 1024 * 1024 * 256,
// },
// )
// protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{
// TBinaryStrictRead: thrift.BoolPtr(true),
// TBinaryStrictWrite: thrift.BoolPtr(true),
// })
//
// This is the correct way to do it:
//
// conf := &thrift.TConfiguration{
// ConnectTimeout: time.Second,
// SocketTimeout: time.Second,
//
// MaxFrameSize: 1024 * 1024 * 256,
//
// TBinaryStrictRead: thrift.BoolPtr(true),
// TBinaryStrictWrite: thrift.BoolPtr(true),
// }
// sccket := thrift.NewTSocketConf("host:port", conf)
// transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf)
// protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf)
//
// [1]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-tconfiguration.md
type TConfiguration struct {
// If <= 0, DEFAULT_MAX_MESSAGE_SIZE will be used instead.
MaxMessageSize int32
// If <= 0, DEFAULT_MAX_FRAME_SIZE will be used instead.
//
// Also if MaxMessageSize < MaxFrameSize,
// MaxMessageSize will be used instead.
MaxFrameSize int32
// Connect and socket timeouts to be used by TSocket and TSSLSocket.
//
// 0 means no timeout.
//
// If <0, DEFAULT_CONNECT_TIMEOUT and DEFAULT_SOCKET_TIMEOUT will be
// used.
ConnectTimeout time.Duration
SocketTimeout time.Duration
// TLS config to be used by TSSLSocket.
TLSConfig *tls.Config
// Strict read/write configurations for TBinaryProtocol.
//
// BoolPtr helper function is available to use literal values.
TBinaryStrictRead *bool
TBinaryStrictWrite *bool
// The wrapped protocol id to be used in THeader transport/protocol.
//
// THeaderProtocolIDPtr and THeaderProtocolIDPtrMust helper functions
// are provided to help filling this value.
THeaderProtocolID *THeaderProtocolID
// Used internally by deprecated constructors, to avoid overriding
// underlying TTransport/TProtocol's cfg by accidental propagations.
//
// For external users this is always false.
noPropagation bool
}
// GetMaxMessageSize returns the max message size an implementation should
// follow.
//
// It's nil-safe. DEFAULT_MAX_MESSAGE_SIZE will be returned if tc is nil.
func (tc *TConfiguration) GetMaxMessageSize() int32 {
if tc == nil || tc.MaxMessageSize <= 0 {
return DEFAULT_MAX_MESSAGE_SIZE
}
return tc.MaxMessageSize
}
// GetMaxFrameSize returns the max frame size an implementation should follow.
//
// It's nil-safe. DEFAULT_MAX_FRAME_SIZE will be returned if tc is nil.
//
// If the configured max message size is smaller than the configured max frame
// size, the smaller one will be returned instead.
func (tc *TConfiguration) GetMaxFrameSize() int32 {
if tc == nil {
return DEFAULT_MAX_FRAME_SIZE
}
maxFrameSize := tc.MaxFrameSize
if maxFrameSize <= 0 {
maxFrameSize = DEFAULT_MAX_FRAME_SIZE
}
if maxMessageSize := tc.GetMaxMessageSize(); maxMessageSize < maxFrameSize {
return maxMessageSize
}
return maxFrameSize
}
// GetConnectTimeout returns the connect timeout should be used by TSocket and
// TSSLSocket.
//
// It's nil-safe. If tc is nil, DEFAULT_CONNECT_TIMEOUT will be returned instead.
func (tc *TConfiguration) GetConnectTimeout() time.Duration {
if tc == nil || tc.ConnectTimeout < 0 {
return DEFAULT_CONNECT_TIMEOUT
}
return tc.ConnectTimeout
}
// GetSocketTimeout returns the socket timeout should be used by TSocket and
// TSSLSocket.
//
// It's nil-safe. If tc is nil, DEFAULT_SOCKET_TIMEOUT will be returned instead.
func (tc *TConfiguration) GetSocketTimeout() time.Duration {
if tc == nil || tc.SocketTimeout < 0 {
return DEFAULT_SOCKET_TIMEOUT
}
return tc.SocketTimeout
}
// GetTLSConfig returns the tls config should be used by TSSLSocket.
//
// It's nil-safe. If tc is nil, nil will be returned instead.
func (tc *TConfiguration) GetTLSConfig() *tls.Config {
if tc == nil {
return nil
}
return tc.TLSConfig
}
// GetTBinaryStrictRead returns the strict read configuration TBinaryProtocol
// should follow.
//
// It's nil-safe. DEFAULT_TBINARY_STRICT_READ will be returned if either tc or
// tc.TBinaryStrictRead is nil.
func (tc *TConfiguration) GetTBinaryStrictRead() bool {
if tc == nil || tc.TBinaryStrictRead == nil {
return DEFAULT_TBINARY_STRICT_READ
}
return *tc.TBinaryStrictRead
}
// GetTBinaryStrictWrite returns the strict read configuration TBinaryProtocol
// should follow.
//
// It's nil-safe. DEFAULT_TBINARY_STRICT_WRITE will be returned if either tc or
// tc.TBinaryStrictWrite is nil.
func (tc *TConfiguration) GetTBinaryStrictWrite() bool {
if tc == nil || tc.TBinaryStrictWrite == nil {
return DEFAULT_TBINARY_STRICT_WRITE
}
return *tc.TBinaryStrictWrite
}
// GetTHeaderProtocolID returns the THeaderProtocolID should be used by
// THeaderProtocol clients (for servers, they always use the same one as the
// client instead).
//
// It's nil-safe. If either tc or tc.THeaderProtocolID is nil,
// THeaderProtocolDefault will be returned instead.
// THeaderProtocolDefault will also be returned if configured value is invalid.
func (tc *TConfiguration) GetTHeaderProtocolID() THeaderProtocolID {
if tc == nil || tc.THeaderProtocolID == nil {
return THeaderProtocolDefault
}
protoID := *tc.THeaderProtocolID
if err := protoID.Validate(); err != nil {
return THeaderProtocolDefault
}
return protoID
}
// THeaderProtocolIDPtr validates and returns the pointer to id.
//
// If id is not a valid THeaderProtocolID, a pointer to THeaderProtocolDefault
// and the validation error will be returned.
func THeaderProtocolIDPtr(id THeaderProtocolID) (*THeaderProtocolID, error) {
err := id.Validate()
if err != nil {
id = THeaderProtocolDefault
}
return &id, err
}
// THeaderProtocolIDPtrMust validates and returns the pointer to id.
//
// It's similar to THeaderProtocolIDPtr, but it panics on validation errors
// instead of returning them.
func THeaderProtocolIDPtrMust(id THeaderProtocolID) *THeaderProtocolID {
ptr, err := THeaderProtocolIDPtr(id)
if err != nil {
panic(err)
}
return ptr
}
// TConfigurationSetter is an optional interface TProtocol, TTransport,
// TProtocolFactory, TTransportFactory, and other implementations can implement.
//
// It's intended to be called during intializations.
// The behavior of calling SetTConfiguration on a TTransport/TProtocol in the
// middle of a message is undefined:
// It may or may not change the behavior of the current processing message,
// and it may even cause the current message to fail.
//
// Note for implementations: SetTConfiguration might be called multiple times
// with the same value in quick successions due to the implementation of the
// propagation. Implementations should make SetTConfiguration as simple as
// possible (usually just overwrite the stored configuration and propagate it to
// the wrapped TTransports/TProtocols).
type TConfigurationSetter interface {
SetTConfiguration(*TConfiguration)
}
// PropagateTConfiguration propagates cfg to impl if impl implements
// TConfigurationSetter and cfg is non-nil, otherwise it does nothing.
//
// NOTE: nil cfg is not propagated. If you want to propagate a TConfiguration
// with everything being default value, use &TConfiguration{} explicitly instead.
func PropagateTConfiguration(impl interface{}, cfg *TConfiguration) {
if cfg == nil || cfg.noPropagation {
return
}
if setter, ok := impl.(TConfigurationSetter); ok {
setter.SetTConfiguration(cfg)
}
}
func checkSizeForProtocol(size int32, cfg *TConfiguration) error {
if size < 0 {
return NewTProtocolExceptionWithType(
NEGATIVE_SIZE,
fmt.Errorf("negative size: %d", size),
)
}
if size > cfg.GetMaxMessageSize() {
return NewTProtocolExceptionWithType(
SIZE_LIMIT,
fmt.Errorf("size exceeded max allowed: %d", size),
)
}
return nil
}
type tTransportFactoryConf struct {
delegate TTransportFactory
cfg *TConfiguration
}
func (f *tTransportFactoryConf) GetTransport(orig TTransport) (TTransport, error) {
trans, err := f.delegate.GetTransport(orig)
if err == nil {
PropagateTConfiguration(orig, f.cfg)
PropagateTConfiguration(trans, f.cfg)
}
return trans, err
}
func (f *tTransportFactoryConf) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(f.delegate, f.cfg)
f.cfg = cfg
}
// TTransportFactoryConf wraps a TTransportFactory to propagate
// TConfiguration on the factory's GetTransport calls.
func TTransportFactoryConf(delegate TTransportFactory, conf *TConfiguration) TTransportFactory {
return &tTransportFactoryConf{
delegate: delegate,
cfg: conf,
}
}
type tProtocolFactoryConf struct {
delegate TProtocolFactory
cfg *TConfiguration
}
func (f *tProtocolFactoryConf) GetProtocol(trans TTransport) TProtocol {
proto := f.delegate.GetProtocol(trans)
PropagateTConfiguration(trans, f.cfg)
PropagateTConfiguration(proto, f.cfg)
return proto
}
func (f *tProtocolFactoryConf) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(f.delegate, f.cfg)
f.cfg = cfg
}
// TProtocolFactoryConf wraps a TProtocolFactory to propagate
// TConfiguration on the factory's GetProtocol calls.
func TProtocolFactoryConf(delegate TProtocolFactory, conf *TConfiguration) TProtocolFactory {
return &tProtocolFactoryConf{
delegate: delegate,
cfg: conf,
}
}
var (
_ TConfigurationSetter = (*tTransportFactoryConf)(nil)
_ TConfigurationSetter = (*tProtocolFactoryConf)(nil)
)
@@ -21,23 +21,58 @@ package thrift
import (
"context"
"log"
"fmt"
)
type TDebugProtocol struct {
Delegate TProtocol
// Required. The actual TProtocol to do the read/write.
Delegate TProtocol
// Optional. The logger and prefix to log all the args/return values
// from Delegate TProtocol calls.
//
// If Logger is nil, StdLogger using stdlib log package with os.Stderr
// will be used. If disable logging is desired, set Logger to NopLogger
// explicitly instead of leaving it as nil/unset.
Logger Logger
LogPrefix string
// Optional. An TProtocol to duplicate everything read/written from Delegate.
//
// A typical use case of this is to use TSimpleJSONProtocol wrapping
// TMemoryBuffer in a middleware to json logging requests/responses.
//
// This feature is not available from TDebugProtocolFactory. In order to
// use it you have to construct TDebugProtocol directly, or set DuplicateTo
// field after getting a TDebugProtocol from the factory.
DuplicateTo TProtocol
}
type TDebugProtocolFactory struct {
Underlying TProtocolFactory
LogPrefix string
Logger Logger
}
// NewTDebugProtocolFactory creates a TDebugProtocolFactory.
//
// Deprecated: Please use NewTDebugProtocolFactoryWithLogger or the struct
// itself instead. This version will use the default logger from standard
// library.
func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory {
return &TDebugProtocolFactory{
Underlying: underlying,
LogPrefix: logPrefix,
Logger: StdLogger(nil),
}
}
// NewTDebugProtocolFactoryWithLogger creates a TDebugProtocolFactory.
func NewTDebugProtocolFactoryWithLogger(underlying TProtocolFactory, logPrefix string, logger Logger) *TDebugProtocolFactory {
return &TDebugProtocolFactory{
Underlying: underlying,
LogPrefix: logPrefix,
Logger: logger,
}
}
@@ -45,226 +80,368 @@ func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return &TDebugProtocol{
Delegate: t.Underlying.GetProtocol(trans),
LogPrefix: t.LogPrefix,
Logger: fallbackLogger(t.Logger),
}
}
func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid)
log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
func (tdp *TDebugProtocol) logf(format string, v ...interface{}) {
fallbackLogger(tdp.Logger)(fmt.Sprintf(format, v...))
}
func (tdp *TDebugProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {
err := tdp.Delegate.WriteMessageBegin(ctx, name, typeId, seqid)
tdp.logf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMessageBegin(ctx, name, typeId, seqid)
}
return err
}
func (tdp *TDebugProtocol) WriteMessageEnd() error {
err := tdp.Delegate.WriteMessageEnd()
log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteMessageEnd(ctx context.Context) error {
err := tdp.Delegate.WriteMessageEnd(ctx)
tdp.logf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMessageEnd(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteStructBegin(name string) error {
err := tdp.Delegate.WriteStructBegin(name)
log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
func (tdp *TDebugProtocol) WriteStructBegin(ctx context.Context, name string) error {
err := tdp.Delegate.WriteStructBegin(ctx, name)
tdp.logf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteStructBegin(ctx, name)
}
return err
}
func (tdp *TDebugProtocol) WriteStructEnd() error {
err := tdp.Delegate.WriteStructEnd()
log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteStructEnd(ctx context.Context) error {
err := tdp.Delegate.WriteStructEnd(ctx)
tdp.logf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteStructEnd(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
err := tdp.Delegate.WriteFieldBegin(name, typeId, id)
log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
func (tdp *TDebugProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
err := tdp.Delegate.WriteFieldBegin(ctx, name, typeId, id)
tdp.logf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteFieldBegin(ctx, name, typeId, id)
}
return err
}
func (tdp *TDebugProtocol) WriteFieldEnd() error {
err := tdp.Delegate.WriteFieldEnd()
log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteFieldEnd(ctx context.Context) error {
err := tdp.Delegate.WriteFieldEnd(ctx)
tdp.logf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteFieldEnd(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteFieldStop() error {
err := tdp.Delegate.WriteFieldStop()
log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteFieldStop(ctx context.Context) error {
err := tdp.Delegate.WriteFieldStop(ctx)
tdp.logf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteFieldStop(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
err := tdp.Delegate.WriteMapBegin(keyType, valueType, size)
log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
func (tdp *TDebugProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
err := tdp.Delegate.WriteMapBegin(ctx, keyType, valueType, size)
tdp.logf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMapBegin(ctx, keyType, valueType, size)
}
return err
}
func (tdp *TDebugProtocol) WriteMapEnd() error {
err := tdp.Delegate.WriteMapEnd()
log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteMapEnd(ctx context.Context) error {
err := tdp.Delegate.WriteMapEnd(ctx)
tdp.logf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMapEnd(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteListBegin(elemType, size)
log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
func (tdp *TDebugProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
err := tdp.Delegate.WriteListBegin(ctx, elemType, size)
tdp.logf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteListBegin(ctx, elemType, size)
}
return err
}
func (tdp *TDebugProtocol) WriteListEnd() error {
err := tdp.Delegate.WriteListEnd()
log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteListEnd(ctx context.Context) error {
err := tdp.Delegate.WriteListEnd(ctx)
tdp.logf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteListEnd(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteSetBegin(elemType, size)
log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
func (tdp *TDebugProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
err := tdp.Delegate.WriteSetBegin(ctx, elemType, size)
tdp.logf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteSetBegin(ctx, elemType, size)
}
return err
}
func (tdp *TDebugProtocol) WriteSetEnd() error {
err := tdp.Delegate.WriteSetEnd()
log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) WriteSetEnd(ctx context.Context) error {
err := tdp.Delegate.WriteSetEnd(ctx)
tdp.logf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteSetEnd(ctx)
}
return err
}
func (tdp *TDebugProtocol) WriteBool(value bool) error {
err := tdp.Delegate.WriteBool(value)
log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteBool(ctx context.Context, value bool) error {
err := tdp.Delegate.WriteBool(ctx, value)
tdp.logf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteBool(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteByte(value int8) error {
err := tdp.Delegate.WriteByte(value)
log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteByte(ctx context.Context, value int8) error {
err := tdp.Delegate.WriteByte(ctx, value)
tdp.logf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteByte(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteI16(value int16) error {
err := tdp.Delegate.WriteI16(value)
log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteI16(ctx context.Context, value int16) error {
err := tdp.Delegate.WriteI16(ctx, value)
tdp.logf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteI16(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteI32(value int32) error {
err := tdp.Delegate.WriteI32(value)
log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteI32(ctx context.Context, value int32) error {
err := tdp.Delegate.WriteI32(ctx, value)
tdp.logf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteI32(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteI64(value int64) error {
err := tdp.Delegate.WriteI64(value)
log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteI64(ctx context.Context, value int64) error {
err := tdp.Delegate.WriteI64(ctx, value)
tdp.logf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteI64(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteDouble(value float64) error {
err := tdp.Delegate.WriteDouble(value)
log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteDouble(ctx context.Context, value float64) error {
err := tdp.Delegate.WriteDouble(ctx, value)
tdp.logf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteDouble(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteString(value string) error {
err := tdp.Delegate.WriteString(value)
log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteString(ctx context.Context, value string) error {
err := tdp.Delegate.WriteString(ctx, value)
tdp.logf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteString(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
err := tdp.Delegate.WriteBinary(value)
log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) WriteBinary(ctx context.Context, value []byte) error {
err := tdp.Delegate.WriteBinary(ctx, value)
tdp.logf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteBinary(ctx, value)
}
return err
}
func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
func (tdp *TDebugProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) {
name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin(ctx)
tdp.logf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMessageBegin(ctx, name, typeId, seqid)
}
return
}
func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
err = tdp.Delegate.ReadMessageEnd()
log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) ReadMessageEnd(ctx context.Context) (err error) {
err = tdp.Delegate.ReadMessageEnd(ctx)
tdp.logf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMessageEnd(ctx)
}
return
}
func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
name, err = tdp.Delegate.ReadStructBegin()
log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
func (tdp *TDebugProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
name, err = tdp.Delegate.ReadStructBegin(ctx)
tdp.logf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteStructBegin(ctx, name)
}
return
}
func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
err = tdp.Delegate.ReadStructEnd()
log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) ReadStructEnd(ctx context.Context) (err error) {
err = tdp.Delegate.ReadStructEnd(ctx)
tdp.logf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteStructEnd(ctx)
}
return
}
func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
func (tdp *TDebugProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error) {
name, typeId, id, err = tdp.Delegate.ReadFieldBegin(ctx)
tdp.logf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteFieldBegin(ctx, name, typeId, id)
}
return
}
func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
err = tdp.Delegate.ReadFieldEnd()
log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) ReadFieldEnd(ctx context.Context) (err error) {
err = tdp.Delegate.ReadFieldEnd(ctx)
tdp.logf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteFieldEnd(ctx)
}
return
}
func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
func (tdp *TDebugProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
keyType, valueType, size, err = tdp.Delegate.ReadMapBegin(ctx)
tdp.logf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMapBegin(ctx, keyType, valueType, size)
}
return
}
func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
err = tdp.Delegate.ReadMapEnd()
log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) ReadMapEnd(ctx context.Context) (err error) {
err = tdp.Delegate.ReadMapEnd(ctx)
tdp.logf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteMapEnd(ctx)
}
return
}
func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadListBegin()
log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
func (tdp *TDebugProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadListBegin(ctx)
tdp.logf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteListBegin(ctx, elemType, size)
}
return
}
func (tdp *TDebugProtocol) ReadListEnd() (err error) {
err = tdp.Delegate.ReadListEnd()
log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) ReadListEnd(ctx context.Context) (err error) {
err = tdp.Delegate.ReadListEnd(ctx)
tdp.logf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteListEnd(ctx)
}
return
}
func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadSetBegin()
log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
func (tdp *TDebugProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadSetBegin(ctx)
tdp.logf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteSetBegin(ctx, elemType, size)
}
return
}
func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
err = tdp.Delegate.ReadSetEnd()
log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
func (tdp *TDebugProtocol) ReadSetEnd(ctx context.Context) (err error) {
err = tdp.Delegate.ReadSetEnd(ctx)
tdp.logf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteSetEnd(ctx)
}
return
}
func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
value, err = tdp.Delegate.ReadBool()
log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadBool(ctx context.Context) (value bool, err error) {
value, err = tdp.Delegate.ReadBool(ctx)
tdp.logf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteBool(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
value, err = tdp.Delegate.ReadByte()
log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadByte(ctx context.Context) (value int8, err error) {
value, err = tdp.Delegate.ReadByte(ctx)
tdp.logf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteByte(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
value, err = tdp.Delegate.ReadI16()
log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadI16(ctx context.Context) (value int16, err error) {
value, err = tdp.Delegate.ReadI16(ctx)
tdp.logf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteI16(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
value, err = tdp.Delegate.ReadI32()
log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadI32(ctx context.Context) (value int32, err error) {
value, err = tdp.Delegate.ReadI32(ctx)
tdp.logf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteI32(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
value, err = tdp.Delegate.ReadI64()
log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadI64(ctx context.Context) (value int64, err error) {
value, err = tdp.Delegate.ReadI64(ctx)
tdp.logf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteI64(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
value, err = tdp.Delegate.ReadDouble()
log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
value, err = tdp.Delegate.ReadDouble(ctx)
tdp.logf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteDouble(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadString() (value string, err error) {
value, err = tdp.Delegate.ReadString()
log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadString(ctx context.Context) (value string, err error) {
value, err = tdp.Delegate.ReadString(ctx)
tdp.logf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteString(ctx, value)
}
return
}
func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
value, err = tdp.Delegate.ReadBinary()
log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
func (tdp *TDebugProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
value, err = tdp.Delegate.ReadBinary(ctx)
tdp.logf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.WriteBinary(ctx, value)
}
return
}
func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
err = tdp.Delegate.Skip(fieldType)
log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
func (tdp *TDebugProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
err = tdp.Delegate.Skip(ctx, fieldType)
tdp.logf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.Skip(ctx, fieldType)
}
return
}
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
tdp.logf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
if tdp.DuplicateTo != nil {
tdp.DuplicateTo.Flush(ctx)
}
return
}
func (tdp *TDebugProtocol) Transport() TTransport {
return tdp.Delegate.Transport()
}
// SetTConfiguration implements TConfigurationSetter for propagation.
func (tdp *TDebugProtocol) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(tdp.Delegate, conf)
PropagateTConfiguration(tdp.DuplicateTo, conf)
}
var _ TConfigurationSetter = (*TDebugProtocol)(nil)
@@ -19,40 +19,103 @@
package thrift
import (
"context"
"sync"
)
type TDeserializer struct {
Transport TTransport
Transport *TMemoryBuffer
Protocol TProtocol
}
func NewTDeserializer() *TDeserializer {
var transport TTransport
transport = NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolTransport(transport)
return &TDeserializer{
transport,
protocol}
Transport: transport,
Protocol: protocol,
}
}
func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
func (t *TDeserializer) ReadString(ctx context.Context, msg TStruct, s string) (err error) {
t.Transport.Reset()
err = nil
if _, err = t.Transport.Write([]byte(s)); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
if err = msg.Read(ctx, t.Protocol); err != nil {
return
}
return
}
func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
func (t *TDeserializer) Read(ctx context.Context, msg TStruct, b []byte) (err error) {
t.Transport.Reset()
err = nil
if _, err = t.Transport.Write(b); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
if err = msg.Read(ctx, t.Protocol); err != nil {
return
}
return
}
// TDeserializerPool is the thread-safe version of TDeserializer,
// it uses resource pool of TDeserializer under the hood.
//
// It must be initialized with either NewTDeserializerPool or
// NewTDeserializerPoolSizeFactory.
type TDeserializerPool struct {
pool sync.Pool
}
// NewTDeserializerPool creates a new TDeserializerPool.
//
// NewTDeserializer can be used as the arg here.
func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
return &TDeserializerPool{
pool: sync.Pool{
New: func() interface{} {
return f()
},
},
}
}
// NewTDeserializerPoolSizeFactory creates a new TDeserializerPool with
// the given size and protocol factory.
//
// Note that the size is not the limit. The TMemoryBuffer underneath can grow
// larger than that. It just dictates the initial size.
func NewTDeserializerPoolSizeFactory(size int, factory TProtocolFactory) *TDeserializerPool {
return &TDeserializerPool{
pool: sync.Pool{
New: func() interface{} {
transport := NewTMemoryBufferLen(size)
protocol := factory.GetProtocol(transport)
return &TDeserializer{
Transport: transport,
Protocol: protocol,
}
},
},
}
}
func (t *TDeserializerPool) ReadString(ctx context.Context, msg TStruct, s string) error {
d := t.pool.Get().(*TDeserializer)
defer t.pool.Put(d)
return d.ReadString(ctx, msg, s)
}
func (t *TDeserializerPool) Read(ctx context.Context, msg TStruct, b []byte) error {
d := t.pool.Get().(*TDeserializer)
defer t.pool.Put(d)
return d.Read(ctx, msg, b)
}
@@ -26,19 +26,91 @@ import (
// Generic Thrift exception
type TException interface {
error
TExceptionType() TExceptionType
}
// Prepends additional information to an error without losing the Thrift exception interface
func PrependError(prepend string, err error) error {
if t, ok := err.(TTransportException); ok {
return NewTTransportException(t.TypeId(), prepend+t.Error())
}
if t, ok := err.(TProtocolException); ok {
return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error()))
}
if t, ok := err.(TApplicationException); ok {
return NewTApplicationException(t.TypeId(), prepend+t.Error())
msg := prepend + err.Error()
var te TException
if errors.As(err, &te) {
switch te.TExceptionType() {
case TExceptionTypeTransport:
if t, ok := err.(TTransportException); ok {
return prependTTransportException(prepend, t)
}
case TExceptionTypeProtocol:
if t, ok := err.(TProtocolException); ok {
return prependTProtocolException(prepend, t)
}
case TExceptionTypeApplication:
var t TApplicationException
if errors.As(err, &t) {
return NewTApplicationException(t.TypeId(), msg)
}
}
return wrappedTException{
err: err,
msg: msg,
tExceptionType: te.TExceptionType(),
}
}
return errors.New(prepend + err.Error())
return errors.New(msg)
}
// TExceptionType is an enum type to categorize different "subclasses" of TExceptions.
type TExceptionType byte
// TExceptionType values
const (
TExceptionTypeUnknown TExceptionType = iota
TExceptionTypeCompiled // TExceptions defined in thrift files and generated by thrift compiler
TExceptionTypeApplication // TApplicationExceptions
TExceptionTypeProtocol // TProtocolExceptions
TExceptionTypeTransport // TTransportExceptions
)
// WrapTException wraps an error into TException.
//
// If err is nil or already TException, it's returned as-is.
// Otherwise it will be wraped into TException with TExceptionType() returning
// TExceptionTypeUnknown, and Unwrap() returning the original error.
func WrapTException(err error) TException {
if err == nil {
return nil
}
if te, ok := err.(TException); ok {
return te
}
return wrappedTException{
err: err,
msg: err.Error(),
tExceptionType: TExceptionTypeUnknown,
}
}
type wrappedTException struct {
err error
msg string
tExceptionType TExceptionType
}
func (w wrappedTException) Error() string {
return w.msg
}
func (w wrappedTException) TExceptionType() TExceptionType {
return w.tExceptionType
}
func (w wrappedTException) Unwrap() error {
return w.err
}
var _ TException = wrappedTException{}
@@ -1,79 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Helper class that encapsulates field metadata.
type field struct {
name string
typeId TType
id int
}
func newField(n string, t TType, i int) *field {
return &field{name: n, typeId: t, id: i}
}
func (p *field) Name() string {
if p == nil {
return ""
}
return p.name
}
func (p *field) TypeId() TType {
if p == nil {
return TType(VOID)
}
return p.typeId
}
func (p *field) Id() int {
if p == nil {
return -1
}
return p.id
}
func (p *field) String() string {
if p == nil {
return "<nil>"
}
return "<TField name:'" + p.name + "' type:" + string(p.typeId) + " field-id:" + string(p.id) + ">"
}
var ANONYMOUS_FIELD *field
type fieldSlice []field
func (p fieldSlice) Len() int {
return len(p)
}
func (p fieldSlice) Less(i, j int) bool {
return p[i].Id() < p[j].Id()
}
func (p fieldSlice) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
func init() {
ANONYMOUS_FIELD = newField("", STOP, 0)
}
@@ -28,44 +28,92 @@ import (
"io"
)
// Deprecated: Use DEFAULT_MAX_FRAME_SIZE instead.
const DEFAULT_MAX_LENGTH = 16384000
type TFramedTransport struct {
transport TTransport
buf bytes.Buffer
reader *bufio.Reader
frameSize uint32 //Current remaining size of the frame. if ==0 read next frame header
buffer [4]byte
maxLength uint32
cfg *TConfiguration
writeBuf bytes.Buffer
reader *bufio.Reader
readBuf bytes.Buffer
buffer [4]byte
}
type tFramedTransportFactory struct {
factory TTransportFactory
maxLength uint32
factory TTransportFactory
cfg *TConfiguration
}
// Deprecated: Use NewTFramedTransportFactoryConf instead.
func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH}
return NewTFramedTransportFactoryConf(factory, &TConfiguration{
MaxFrameSize: DEFAULT_MAX_LENGTH,
noPropagation: true,
})
}
// Deprecated: Use NewTFramedTransportFactoryConf instead.
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
return NewTFramedTransportFactoryConf(factory, &TConfiguration{
MaxFrameSize: int32(maxLength),
noPropagation: true,
})
}
func NewTFramedTransportFactoryConf(factory TTransportFactory, conf *TConfiguration) TTransportFactory {
PropagateTConfiguration(factory, conf)
return &tFramedTransportFactory{
factory: factory,
cfg: conf,
}
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
PropagateTConfiguration(base, p.cfg)
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
return NewTFramedTransportMaxLength(tt, p.maxLength), nil
return NewTFramedTransportConf(tt, p.cfg), nil
}
func (p *tFramedTransportFactory) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(p.factory, cfg)
p.cfg = cfg
}
// Deprecated: Use NewTFramedTransportConf instead.
func NewTFramedTransport(transport TTransport) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH}
return NewTFramedTransportConf(transport, &TConfiguration{
MaxFrameSize: DEFAULT_MAX_LENGTH,
noPropagation: true,
})
}
// Deprecated: Use NewTFramedTransportConf instead.
func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength}
return NewTFramedTransportConf(transport, &TConfiguration{
MaxFrameSize: int32(maxLength),
noPropagation: true,
})
}
func NewTFramedTransportConf(transport TTransport, conf *TConfiguration) *TFramedTransport {
PropagateTConfiguration(transport, conf)
return &TFramedTransport{
transport: transport,
reader: bufio.NewReader(transport),
cfg: conf,
}
}
func (p *TFramedTransport) Open() error {
@@ -80,89 +128,65 @@ func (p *TFramedTransport) Close() error {
return p.transport.Close()
}
func (p *TFramedTransport) Read(buf []byte) (l int, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
read, err = p.readBuf.Read(buf)
if err != io.EOF {
return
}
if p.frameSize < uint32(len(buf)) {
frameSize := p.frameSize
tmp := make([]byte, p.frameSize)
l, err = p.Read(tmp)
copy(buf, tmp)
if err == nil {
// Note: It's important to only return an error when l
// is zero.
// In io.Reader.Read interface, it's perfectly fine to
// return partial data and nil error, which means
// "This is all the data we have right now without
// blocking. If you need the full data, call Read again
// or use io.ReadFull instead".
// Returning partial data with an error actually means
// there's no more data after the partial data just
// returned, which is not true in this case
// (it might be that the other end just haven't written
// them yet).
if l == 0 {
err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf)))
}
return
}
// For bytes.Buffer.Read, EOF would only happen when read is zero,
// but still, do a sanity check,
// in case that behavior is changed in a future version of go stdlib.
// When that happens, just return nil error,
// and let the caller call Read again to read the next frame.
if read > 0 {
return read, nil
}
got, err := p.reader.Read(buf)
p.frameSize = p.frameSize - uint32(got)
//sanity check
if p.frameSize < 0 {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size")
// Reaching here means that the last Read finished the last frame,
// so we need to read the next frame into readBuf now.
if err = p.readFrame(); err != nil {
return read, err
}
return got, NewTTransportExceptionFromError(err)
newRead, err := p.Read(buf[read:])
return read + newRead, err
}
func (p *TFramedTransport) ReadByte() (c byte, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < 1 {
return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", p.frameSize, 1))
}
c, err = p.reader.ReadByte()
if err == nil {
p.frameSize--
buf := p.buffer[:1]
_, err = p.Read(buf)
if err != nil {
return
}
c = buf[0]
return
}
func (p *TFramedTransport) Write(buf []byte) (int, error) {
n, err := p.buf.Write(buf)
n, err := p.writeBuf.Write(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) WriteByte(c byte) error {
return p.buf.WriteByte(c)
return p.writeBuf.WriteByte(c)
}
func (p *TFramedTransport) WriteString(s string) (n int, err error) {
return p.buf.WriteString(s)
return p.writeBuf.WriteString(s)
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
size := p.buf.Len()
size := p.writeBuf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
p.buf.Truncate(0)
p.writeBuf.Reset()
return NewTTransportExceptionFromError(err)
}
if size > 0 {
if n, err := p.buf.WriteTo(p.transport); err != nil {
print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n")
p.buf.Truncate(0)
if _, err := io.Copy(p.transport, &p.writeBuf); err != nil {
p.writeBuf.Reset()
return NewTTransportExceptionFromError(err)
}
}
@@ -170,18 +194,30 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) readFrameHeader() (uint32, error) {
func (p *TFramedTransport) readFrame() error {
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
return 0, err
return err
}
size := binary.BigEndian.Uint32(buf)
if size < 0 || size > p.maxLength {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
if size < 0 || size > uint32(p.cfg.GetMaxFrameSize()) {
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
return size, nil
_, err := io.CopyN(&p.readBuf, p.reader, int64(size))
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.frameSize)
return uint64(p.readBuf.Len())
}
// SetTConfiguration implements TConfigurationSetter.
func (p *TFramedTransport) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(p.transport, cfg)
p.cfg = cfg
}
var (
_ TConfigurationSetter = (*tFramedTransportFactory)(nil)
_ TConfigurationSetter = (*TFramedTransport)(nil)
)
@@ -44,6 +44,15 @@ func SetHeader(ctx context.Context, key, value string) context.Context {
)
}
// UnsetHeader unsets a previously set header in the context.
func UnsetHeader(ctx context.Context, key string) context.Context {
return context.WithValue(
ctx,
headerKey(key),
nil,
)
}
// GetHeader returns a value of the given header from the context.
func GetHeader(ctx context.Context, key string) (value string, ok bool) {
if v := ctx.Value(headerKey(key)); v != nil {
@@ -21,6 +21,7 @@ package thrift
import (
"context"
"errors"
)
// THeaderProtocol is a thrift protocol that implements THeader:
@@ -34,34 +35,65 @@ type THeaderProtocol struct {
// Will be initialized on first read/write.
protocol TProtocol
cfg *TConfiguration
}
// NewTHeaderProtocol creates a new THeaderProtocol from the underlying
// transport. The passed in transport will be wrapped with THeaderTransport.
// Deprecated: Use NewTHeaderProtocolConf instead.
func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
return newTHeaderProtocolConf(trans, &TConfiguration{
noPropagation: true,
})
}
// NewTHeaderProtocolConf creates a new THeaderProtocol from the underlying
// transport with given TConfiguration.
//
// The passed in transport will be wrapped with THeaderTransport.
//
// Note that THeaderTransport handles frame and zlib by itself,
// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
t := NewTHeaderTransport(trans)
p, _ := THeaderProtocolDefault.GetProtocol(t)
func NewTHeaderProtocolConf(trans TTransport, conf *TConfiguration) *THeaderProtocol {
return newTHeaderProtocolConf(trans, conf)
}
func newTHeaderProtocolConf(trans TTransport, cfg *TConfiguration) *THeaderProtocol {
t := NewTHeaderTransportConf(trans, cfg)
p, _ := t.cfg.GetTHeaderProtocolID().GetProtocol(t)
PropagateTConfiguration(p, cfg)
return &THeaderProtocol{
transport: t,
protocol: p,
cfg: cfg,
}
}
type tHeaderProtocolFactory struct{}
func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTHeaderProtocol(trans)
type tHeaderProtocolFactory struct {
cfg *TConfiguration
}
// NewTHeaderProtocolFactory creates a factory for THeader.
//
// It's a wrapper for NewTHeaderProtocol
func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return newTHeaderProtocolConf(trans, f.cfg)
}
func (f *tHeaderProtocolFactory) SetTConfiguration(cfg *TConfiguration) {
f.cfg = cfg
}
// Deprecated: Use NewTHeaderProtocolFactoryConf instead.
func NewTHeaderProtocolFactory() TProtocolFactory {
return tHeaderProtocolFactory{}
return NewTHeaderProtocolFactoryConf(&TConfiguration{
noPropagation: true,
})
}
// NewTHeaderProtocolFactoryConf creates a factory for THeader with given
// TConfiguration.
func NewTHeaderProtocolFactoryConf(conf *TConfiguration) TProtocolFactory {
return tHeaderProtocolFactory{
cfg: conf,
}
}
// Transport returns the underlying transport.
@@ -95,211 +127,225 @@ func (p *THeaderProtocol) Flush(ctx context.Context) error {
return p.transport.Flush(ctx)
}
func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error {
func (p *THeaderProtocol) WriteMessageBegin(ctx context.Context, name string, typeID TMessageType, seqID int32) error {
newProto, err := p.transport.Protocol().GetProtocol(p.transport)
if err != nil {
return err
}
PropagateTConfiguration(newProto, p.cfg)
p.protocol = newProto
p.transport.SequenceID = seqID
return p.protocol.WriteMessageBegin(name, typeID, seqID)
return p.protocol.WriteMessageBegin(ctx, name, typeID, seqID)
}
func (p *THeaderProtocol) WriteMessageEnd() error {
if err := p.protocol.WriteMessageEnd(); err != nil {
func (p *THeaderProtocol) WriteMessageEnd(ctx context.Context) error {
if err := p.protocol.WriteMessageEnd(ctx); err != nil {
return err
}
return p.transport.Flush(context.Background())
return p.transport.Flush(ctx)
}
func (p *THeaderProtocol) WriteStructBegin(name string) error {
return p.protocol.WriteStructBegin(name)
func (p *THeaderProtocol) WriteStructBegin(ctx context.Context, name string) error {
return p.protocol.WriteStructBegin(ctx, name)
}
func (p *THeaderProtocol) WriteStructEnd() error {
return p.protocol.WriteStructEnd()
func (p *THeaderProtocol) WriteStructEnd(ctx context.Context) error {
return p.protocol.WriteStructEnd(ctx)
}
func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error {
return p.protocol.WriteFieldBegin(name, typeID, id)
func (p *THeaderProtocol) WriteFieldBegin(ctx context.Context, name string, typeID TType, id int16) error {
return p.protocol.WriteFieldBegin(ctx, name, typeID, id)
}
func (p *THeaderProtocol) WriteFieldEnd() error {
return p.protocol.WriteFieldEnd()
func (p *THeaderProtocol) WriteFieldEnd(ctx context.Context) error {
return p.protocol.WriteFieldEnd(ctx)
}
func (p *THeaderProtocol) WriteFieldStop() error {
return p.protocol.WriteFieldStop()
func (p *THeaderProtocol) WriteFieldStop(ctx context.Context) error {
return p.protocol.WriteFieldStop(ctx)
}
func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
return p.protocol.WriteMapBegin(keyType, valueType, size)
func (p *THeaderProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
return p.protocol.WriteMapBegin(ctx, keyType, valueType, size)
}
func (p *THeaderProtocol) WriteMapEnd() error {
return p.protocol.WriteMapEnd()
func (p *THeaderProtocol) WriteMapEnd(ctx context.Context) error {
return p.protocol.WriteMapEnd(ctx)
}
func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error {
return p.protocol.WriteListBegin(elemType, size)
func (p *THeaderProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
return p.protocol.WriteListBegin(ctx, elemType, size)
}
func (p *THeaderProtocol) WriteListEnd() error {
return p.protocol.WriteListEnd()
func (p *THeaderProtocol) WriteListEnd(ctx context.Context) error {
return p.protocol.WriteListEnd(ctx)
}
func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error {
return p.protocol.WriteSetBegin(elemType, size)
func (p *THeaderProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
return p.protocol.WriteSetBegin(ctx, elemType, size)
}
func (p *THeaderProtocol) WriteSetEnd() error {
return p.protocol.WriteSetEnd()
func (p *THeaderProtocol) WriteSetEnd(ctx context.Context) error {
return p.protocol.WriteSetEnd(ctx)
}
func (p *THeaderProtocol) WriteBool(value bool) error {
return p.protocol.WriteBool(value)
func (p *THeaderProtocol) WriteBool(ctx context.Context, value bool) error {
return p.protocol.WriteBool(ctx, value)
}
func (p *THeaderProtocol) WriteByte(value int8) error {
return p.protocol.WriteByte(value)
func (p *THeaderProtocol) WriteByte(ctx context.Context, value int8) error {
return p.protocol.WriteByte(ctx, value)
}
func (p *THeaderProtocol) WriteI16(value int16) error {
return p.protocol.WriteI16(value)
func (p *THeaderProtocol) WriteI16(ctx context.Context, value int16) error {
return p.protocol.WriteI16(ctx, value)
}
func (p *THeaderProtocol) WriteI32(value int32) error {
return p.protocol.WriteI32(value)
func (p *THeaderProtocol) WriteI32(ctx context.Context, value int32) error {
return p.protocol.WriteI32(ctx, value)
}
func (p *THeaderProtocol) WriteI64(value int64) error {
return p.protocol.WriteI64(value)
func (p *THeaderProtocol) WriteI64(ctx context.Context, value int64) error {
return p.protocol.WriteI64(ctx, value)
}
func (p *THeaderProtocol) WriteDouble(value float64) error {
return p.protocol.WriteDouble(value)
func (p *THeaderProtocol) WriteDouble(ctx context.Context, value float64) error {
return p.protocol.WriteDouble(ctx, value)
}
func (p *THeaderProtocol) WriteString(value string) error {
return p.protocol.WriteString(value)
func (p *THeaderProtocol) WriteString(ctx context.Context, value string) error {
return p.protocol.WriteString(ctx, value)
}
func (p *THeaderProtocol) WriteBinary(value []byte) error {
return p.protocol.WriteBinary(value)
func (p *THeaderProtocol) WriteBinary(ctx context.Context, value []byte) error {
return p.protocol.WriteBinary(ctx, value)
}
// ReadFrame calls underlying THeaderTransport's ReadFrame function.
func (p *THeaderProtocol) ReadFrame() error {
return p.transport.ReadFrame()
func (p *THeaderProtocol) ReadFrame(ctx context.Context) error {
return p.transport.ReadFrame(ctx)
}
func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) {
if err = p.transport.ReadFrame(); err != nil {
func (p *THeaderProtocol) ReadMessageBegin(ctx context.Context) (name string, typeID TMessageType, seqID int32, err error) {
if err = p.transport.ReadFrame(ctx); err != nil {
return
}
var newProto TProtocol
newProto, err = p.transport.Protocol().GetProtocol(p.transport)
if err != nil {
tAppExc, ok := err.(TApplicationException)
if !ok {
var tAppExc TApplicationException
if !errors.As(err, &tAppExc) {
return
}
if e := p.protocol.WriteMessageBegin("", EXCEPTION, seqID); e != nil {
if e := p.protocol.WriteMessageBegin(ctx, "", EXCEPTION, seqID); e != nil {
return
}
if e := tAppExc.Write(p.protocol); e != nil {
if e := tAppExc.Write(ctx, p.protocol); e != nil {
return
}
if e := p.protocol.WriteMessageEnd(); e != nil {
if e := p.protocol.WriteMessageEnd(ctx); e != nil {
return
}
if e := p.transport.Flush(context.Background()); e != nil {
if e := p.transport.Flush(ctx); e != nil {
return
}
return
}
PropagateTConfiguration(newProto, p.cfg)
p.protocol = newProto
return p.protocol.ReadMessageBegin()
return p.protocol.ReadMessageBegin(ctx)
}
func (p *THeaderProtocol) ReadMessageEnd() error {
return p.protocol.ReadMessageEnd()
func (p *THeaderProtocol) ReadMessageEnd(ctx context.Context) error {
return p.protocol.ReadMessageEnd(ctx)
}
func (p *THeaderProtocol) ReadStructBegin() (name string, err error) {
return p.protocol.ReadStructBegin()
func (p *THeaderProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
return p.protocol.ReadStructBegin(ctx)
}
func (p *THeaderProtocol) ReadStructEnd() error {
return p.protocol.ReadStructEnd()
func (p *THeaderProtocol) ReadStructEnd(ctx context.Context) error {
return p.protocol.ReadStructEnd(ctx)
}
func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) {
return p.protocol.ReadFieldBegin()
func (p *THeaderProtocol) ReadFieldBegin(ctx context.Context) (name string, typeID TType, id int16, err error) {
return p.protocol.ReadFieldBegin(ctx)
}
func (p *THeaderProtocol) ReadFieldEnd() error {
return p.protocol.ReadFieldEnd()
func (p *THeaderProtocol) ReadFieldEnd(ctx context.Context) error {
return p.protocol.ReadFieldEnd(ctx)
}
func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
return p.protocol.ReadMapBegin()
func (p *THeaderProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
return p.protocol.ReadMapBegin(ctx)
}
func (p *THeaderProtocol) ReadMapEnd() error {
return p.protocol.ReadMapEnd()
func (p *THeaderProtocol) ReadMapEnd(ctx context.Context) error {
return p.protocol.ReadMapEnd(ctx)
}
func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) {
return p.protocol.ReadListBegin()
func (p *THeaderProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
return p.protocol.ReadListBegin(ctx)
}
func (p *THeaderProtocol) ReadListEnd() error {
return p.protocol.ReadListEnd()
func (p *THeaderProtocol) ReadListEnd(ctx context.Context) error {
return p.protocol.ReadListEnd(ctx)
}
func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return p.protocol.ReadSetBegin()
func (p *THeaderProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
return p.protocol.ReadSetBegin(ctx)
}
func (p *THeaderProtocol) ReadSetEnd() error {
return p.protocol.ReadSetEnd()
func (p *THeaderProtocol) ReadSetEnd(ctx context.Context) error {
return p.protocol.ReadSetEnd(ctx)
}
func (p *THeaderProtocol) ReadBool() (value bool, err error) {
return p.protocol.ReadBool()
func (p *THeaderProtocol) ReadBool(ctx context.Context) (value bool, err error) {
return p.protocol.ReadBool(ctx)
}
func (p *THeaderProtocol) ReadByte() (value int8, err error) {
return p.protocol.ReadByte()
func (p *THeaderProtocol) ReadByte(ctx context.Context) (value int8, err error) {
return p.protocol.ReadByte(ctx)
}
func (p *THeaderProtocol) ReadI16() (value int16, err error) {
return p.protocol.ReadI16()
func (p *THeaderProtocol) ReadI16(ctx context.Context) (value int16, err error) {
return p.protocol.ReadI16(ctx)
}
func (p *THeaderProtocol) ReadI32() (value int32, err error) {
return p.protocol.ReadI32()
func (p *THeaderProtocol) ReadI32(ctx context.Context) (value int32, err error) {
return p.protocol.ReadI32(ctx)
}
func (p *THeaderProtocol) ReadI64() (value int64, err error) {
return p.protocol.ReadI64()
func (p *THeaderProtocol) ReadI64(ctx context.Context) (value int64, err error) {
return p.protocol.ReadI64(ctx)
}
func (p *THeaderProtocol) ReadDouble() (value float64, err error) {
return p.protocol.ReadDouble()
func (p *THeaderProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
return p.protocol.ReadDouble(ctx)
}
func (p *THeaderProtocol) ReadString() (value string, err error) {
return p.protocol.ReadString()
func (p *THeaderProtocol) ReadString(ctx context.Context) (value string, err error) {
return p.protocol.ReadString(ctx)
}
func (p *THeaderProtocol) ReadBinary() (value []byte, err error) {
return p.protocol.ReadBinary()
func (p *THeaderProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
return p.protocol.ReadBinary(ctx)
}
func (p *THeaderProtocol) Skip(fieldType TType) error {
return p.protocol.Skip(fieldType)
func (p *THeaderProtocol) Skip(ctx context.Context, fieldType TType) error {
return p.protocol.Skip(ctx, fieldType)
}
// SetTConfiguration implements TConfigurationSetter.
func (p *THeaderProtocol) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(p.transport, cfg)
PropagateTConfiguration(p.protocol, cfg)
p.cfg = cfg
}
var (
_ TConfigurationSetter = (*tHeaderProtocolFactory)(nil)
_ TConfigurationSetter = (*THeaderProtocol)(nil)
)
@@ -75,6 +75,15 @@ const (
THeaderProtocolDefault = THeaderProtocolBinary
)
// Declared globally to avoid repetitive allocations, not really used.
var globalMemoryBuffer = NewTMemoryBuffer()
// Validate checks whether the THeaderProtocolID is a valid/supported one.
func (id THeaderProtocolID) Validate() error {
_, err := id.GetProtocol(globalMemoryBuffer)
return err
}
// GetProtocol gets the corresponding TProtocol from the wrapped protocol id.
func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
switch id {
@@ -84,7 +93,7 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
fmt.Sprintf("THeader protocol id %d not supported", id),
)
case THeaderProtocolBinary:
return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil
return NewTBinaryProtocolTransport(trans), nil
case THeaderProtocolCompact:
return NewTCompactProtocol(trans), nil
}
@@ -93,11 +102,12 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
// THeaderTransformID defines the numeric id of the transform used.
type THeaderTransformID int32
// THeaderTransformID values
// THeaderTransformID values.
//
// Values not defined here are not currently supported, namely HMAC and Snappy.
const (
TransformNone THeaderTransformID = iota // 0, no special handling
TransformZlib // 1, zlib
// Rest of the values are not currently supported, namely HMAC and Snappy.
)
var supportedTransformIDs = map[THeaderTransformID]bool{
@@ -255,6 +265,7 @@ type THeaderTransport struct {
clientType clientType
protocolID THeaderProtocolID
cfg *TConfiguration
// buffer is used in the following scenarios to avoid repetitive
// allocations, while 4 is big enough for all those scenarios:
@@ -266,22 +277,35 @@ type THeaderTransport struct {
var _ TTransport = (*THeaderTransport)(nil)
// NewTHeaderTransport creates THeaderTransport from the underlying transport.
//
// Please note that THeaderTransport handles framing and zlib by itself,
// so the underlying transport should be the raw socket transports (TSocket or TSSLSocket),
// instead of rich transports like TZlibTransport or TFramedTransport.
//
// If trans is already a *THeaderTransport, it will be returned as is.
// Deprecated: Use NewTHeaderTransportConf instead.
func NewTHeaderTransport(trans TTransport) *THeaderTransport {
return NewTHeaderTransportConf(trans, &TConfiguration{
noPropagation: true,
})
}
// NewTHeaderTransportConf creates THeaderTransport from the
// underlying transport, with given TConfiguration attached.
//
// If trans is already a *THeaderTransport, it will be returned as is,
// but with TConfiguration overridden by the value passed in.
//
// The protocol ID in TConfiguration is only useful for client transports.
// For servers,
// the protocol ID will be overridden again to the one set by the client,
// to ensure that servers always speak the same dialect as the client.
func NewTHeaderTransportConf(trans TTransport, conf *TConfiguration) *THeaderTransport {
if ht, ok := trans.(*THeaderTransport); ok {
ht.SetTConfiguration(conf)
return ht
}
PropagateTConfiguration(trans, conf)
return &THeaderTransport{
transport: trans,
reader: bufio.NewReader(trans),
writeHeaders: make(THeaderMap),
protocolID: THeaderProtocolDefault,
protocolID: conf.GetTHeaderProtocolID(),
cfg: conf,
}
}
@@ -297,18 +321,34 @@ func (t *THeaderTransport) IsOpen() bool {
// ReadFrame tries to read the frame header, guess the client type, and handle
// unframed clients.
func (t *THeaderTransport) ReadFrame() error {
func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
if !t.needReadFrame() {
// No need to read frame, skipping.
return nil
}
// Peek and handle the first 32 bits.
// They could either be the length field of a framed message,
// or the first bytes of an unframed message.
buf, err := t.reader.Peek(size32)
var buf []byte
var err error
// This is also usually the first read from a connection,
// so handle retries around socket timeouts.
_, deadlineSet := ctx.Deadline()
for {
buf, err = t.reader.Peek(size32)
if deadlineSet && isTimeoutError(err) && ctx.Err() == nil {
// This is I/O timeout and we still have time,
// continue trying
continue
}
// For anything else, do not retry
break
}
if err != nil {
return err
}
frameSize := binary.BigEndian.Uint32(buf)
if frameSize&VERSION_MASK == VERSION_1 {
t.clientType = clientUnframedBinary
@@ -321,7 +361,7 @@ func (t *THeaderTransport) ReadFrame() error {
// At this point it should be a framed message,
// sanity check on frameSize then discard the peeked part.
if frameSize > THeaderMaxFrameSize {
if frameSize > THeaderMaxFrameSize || frameSize > uint32(t.cfg.GetMaxFrameSize()) {
return NewTProtocolExceptionWithType(
SIZE_LIMIT,
errors.New("frame too large"),
@@ -330,10 +370,7 @@ func (t *THeaderTransport) ReadFrame() error {
t.reader.Discard(size32)
// Read the frame fully into frameBuffer.
_, err = io.Copy(
&t.frameBuffer,
io.LimitReader(t.reader, int64(frameSize)),
)
_, err = io.CopyN(&t.frameBuffer, t.reader, int64(frameSize))
if err != nil {
return err
}
@@ -344,7 +381,7 @@ func (t *THeaderTransport) ReadFrame() error {
version := binary.BigEndian.Uint32(buf)
if version&THeaderHeaderMask == THeaderHeaderMagic {
t.clientType = clientHeaders
return t.parseHeaders(frameSize)
return t.parseHeaders(ctx, frameSize)
}
if version&VERSION_MASK == VERSION_1 {
t.clientType = clientFramedBinary
@@ -374,7 +411,7 @@ func (t *THeaderTransport) endOfFrame() error {
return t.frameReader.Close()
}
func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) error {
if t.clientType != clientHeaders {
return nil
}
@@ -395,11 +432,12 @@ func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
)
}
headerBuf := NewTMemoryBuffer()
_, err = io.Copy(headerBuf, io.LimitReader(&t.frameBuffer, headerLength))
_, err = io.CopyN(headerBuf, &t.frameBuffer, headerLength)
if err != nil {
return err
}
hp := NewTCompactProtocol(headerBuf)
hp.SetTConfiguration(t.cfg)
// At this point the header is already read into headerBuf,
// and t.frameBuffer starts from the actual payload.
@@ -408,6 +446,7 @@ func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
return err
}
t.protocolID = THeaderProtocolID(protoID)
var transformCount int32
transformCount, err = hp.readVarint32()
if err != nil {
@@ -442,7 +481,7 @@ func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
headers := make(THeaderMap)
for {
infoType, err := hp.readVarint32()
if err == io.EOF {
if errors.Is(err, io.EOF) {
break
}
if err != nil {
@@ -454,11 +493,11 @@ func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
return err
}
for i := 0; i < int(count); i++ {
key, err := hp.ReadString()
key, err := hp.ReadString(ctx)
if err != nil {
return err
}
value, err := hp.ReadString()
value, err := hp.ReadString(ctx)
if err != nil {
return err
}
@@ -488,21 +527,37 @@ func (t *THeaderTransport) needReadFrame() bool {
}
func (t *THeaderTransport) Read(p []byte) (read int, err error) {
err = t.ReadFrame()
// Here using context.Background instead of a context passed in is safe.
// First is that there's no way to pass context into this function.
// Then, 99% of the case when calling this Read frame is already read
// into frameReader. ReadFrame here is more of preventing bugs that
// didn't call ReadFrame before calling Read.
err = t.ReadFrame(context.Background())
if err != nil {
return
}
if t.frameReader != nil {
read, err = t.frameReader.Read(p)
if err == io.EOF {
if err == nil && t.frameBuffer.Len() <= 0 {
// the last Read finished the frame, do endOfFrame
// handling here.
err = t.endOfFrame()
} else if err == io.EOF {
err = t.endOfFrame()
if err != nil {
return
}
if read < len(p) {
var nextRead int
nextRead, err = t.Read(p[read:])
read += nextRead
if read == 0 {
// Try to read the next frame when we hit EOF
// (end of frame) immediately.
// When we got here, it means the last read
// finished the previous frame, but didn't
// do endOfFrame handling yet.
// We have to read the next frame here,
// as otherwise we would return 0 and nil,
// which is a case not handled well by most
// protocol implementations.
return t.Read(p)
}
}
return
@@ -534,6 +589,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
case clientHeaders:
headers := NewTMemoryBuffer()
hp := NewTCompactProtocol(headers)
hp.SetTConfiguration(t.cfg)
if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
return NewTTransportExceptionFromError(err)
}
@@ -553,10 +609,10 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
return NewTTransportExceptionFromError(err)
}
for key, value := range t.writeHeaders {
if err := hp.WriteString(key); err != nil {
if err := hp.WriteString(ctx, key); err != nil {
return NewTTransportExceptionFromError(err)
}
if err := hp.WriteString(value); err != nil {
if err := hp.WriteString(ctx, value); err != nil {
return NewTTransportExceptionFromError(err)
}
}
@@ -696,17 +752,37 @@ func (t *THeaderTransport) isFramed() bool {
}
}
// SetTConfiguration implements TConfigurationSetter.
func (t *THeaderTransport) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(t.transport, cfg)
t.cfg = cfg
}
// THeaderTransportFactory is a TTransportFactory implementation to create
// THeaderTransport.
//
// It also implements TConfigurationSetter.
type THeaderTransportFactory struct {
// The underlying factory, could be nil.
Factory TTransportFactory
cfg *TConfiguration
}
// NewTHeaderTransportFactory creates a new *THeaderTransportFactory.
// Deprecated: Use NewTHeaderTransportFactoryConf instead.
func NewTHeaderTransportFactory(factory TTransportFactory) TTransportFactory {
return NewTHeaderTransportFactoryConf(factory, &TConfiguration{
noPropagation: true,
})
}
// NewTHeaderTransportFactoryConf creates a new *THeaderTransportFactory with
// the given *TConfiguration.
func NewTHeaderTransportFactoryConf(factory TTransportFactory, conf *TConfiguration) TTransportFactory {
return &THeaderTransportFactory{
Factory: factory,
cfg: conf,
}
}
@@ -717,7 +793,18 @@ func (f *THeaderTransportFactory) GetTransport(trans TTransport) (TTransport, er
if err != nil {
return nil, err
}
return NewTHeaderTransport(t), nil
return NewTHeaderTransportConf(t, f.cfg), nil
}
return NewTHeaderTransport(trans), nil
return NewTHeaderTransportConf(trans, f.cfg), nil
}
// SetTConfiguration implements TConfigurationSetter.
func (f *THeaderTransportFactory) SetTConfiguration(cfg *TConfiguration) {
PropagateTConfiguration(f.Factory, f.cfg)
f.cfg = cfg
}
var (
_ TConfigurationSetter = (*THeaderTransportFactory)(nil)
_ TConfigurationSetter = (*THeaderTransport)(nil)
)
@@ -22,6 +22,7 @@ package thrift
import (
"bytes"
"context"
"errors"
"io"
"io/ioutil"
"net/http"
@@ -159,26 +160,37 @@ func (p *THttpClient) Read(buf []byte) (int, error) {
return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.")
}
n, err := p.response.Body.Read(buf)
if n > 0 && (err == nil || err == io.EOF) {
if n > 0 && (err == nil || errors.Is(err, io.EOF)) {
return n, nil
}
return n, NewTTransportExceptionFromError(err)
}
func (p *THttpClient) ReadByte() (c byte, err error) {
if p.response == nil {
return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.")
}
return readByte(p.response.Body)
}
func (p *THttpClient) Write(buf []byte) (int, error) {
n, err := p.requestBuffer.Write(buf)
return n, err
if p.requestBuffer == nil {
return 0, NewTTransportException(NOT_OPEN, "Request buffer is nil, connection may have been closed.")
}
return p.requestBuffer.Write(buf)
}
func (p *THttpClient) WriteByte(c byte) error {
if p.requestBuffer == nil {
return NewTTransportException(NOT_OPEN, "Request buffer is nil, connection may have been closed.")
}
return p.requestBuffer.WriteByte(c)
}
func (p *THttpClient) WriteString(s string) (n int, err error) {
if p.requestBuffer == nil {
return 0, NewTTransportException(NOT_OPEN, "Request buffer is nil, connection may have been closed.")
}
return p.requestBuffer.WriteString(s)
}
@@ -186,7 +198,11 @@ func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
// Give up the ownership of the current request buffer to http request,
// and create a new buffer for the next request.
buf := p.requestBuffer
p.requestBuffer = new(bytes.Buffer)
req, err := http.NewRequest("POST", p.url.String(), buf)
if err != nil {
return NewTTransportExceptionFromError(err)
}
@@ -218,7 +234,7 @@ func (p *THttpClient) RemainingBytes() (num_bytes uint64) {
}
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the truth is, we just don't know unless framed is used
}
// Deprecated: Use NewTHttpClientTransportFactory instead.
@@ -24,6 +24,7 @@ import (
"io"
"net/http"
"strings"
"sync"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
@@ -40,14 +41,24 @@ func NewThriftHandlerFunc(processor TProcessor,
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
sp := &sync.Pool{
New: func() interface{} {
return gzip.NewWriter(nil)
},
}
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gz := sp.Get().(*gzip.Writer)
gz.Reset(w)
defer func() {
_ = gz.Close()
sp.Put(gz)
}()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
}
@@ -210,5 +210,13 @@ func (p *StreamTransport) WriteString(s string) (n int, err error) {
func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the truth is, we just don't know unless framed is used
}
// SetTConfiguration implements TConfigurationSetter for propagation.
func (p *StreamTransport) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(p.Reader, conf)
PropagateTConfiguration(p.Writer, conf)
}
var _ TConfigurationSetter = (*StreamTransport)(nil)
@@ -41,8 +41,8 @@ type TJSONProtocol struct {
// Constructor
func NewTJSONProtocol(t TTransport) *TJSONProtocol {
v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
return v
}
@@ -57,43 +57,43 @@ func NewTJSONProtocolFactory() *TJSONProtocolFactory {
return &TJSONProtocolFactory{}
}
func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
func (p *TJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil {
if e := p.WriteI32(ctx, THRIFT_JSON_PROTOCOL_VERSION); e != nil {
return e
}
if e := p.WriteString(name); e != nil {
if e := p.WriteString(ctx, name); e != nil {
return e
}
if e := p.WriteByte(int8(typeId)); e != nil {
if e := p.WriteByte(ctx, int8(typeId)); e != nil {
return e
}
if e := p.WriteI32(seqId); e != nil {
if e := p.WriteI32(ctx, seqId); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteMessageEnd() error {
func (p *TJSONProtocol) WriteMessageEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteStructBegin(name string) error {
func (p *TJSONProtocol) WriteStructBegin(ctx context.Context, name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteStructEnd() error {
func (p *TJSONProtocol) WriteStructEnd(ctx context.Context) error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if e := p.WriteI16(id); e != nil {
func (p *TJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
if e := p.WriteI16(ctx, id); e != nil {
return e
}
if e := p.OutputObjectBegin(); e != nil {
@@ -103,19 +103,19 @@ func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) err
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
if e := p.WriteString(ctx, s); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteFieldEnd() error {
func (p *TJSONProtocol) WriteFieldEnd(ctx context.Context) error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldStop() error { return nil }
func (p *TJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil }
func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
func (p *TJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
@@ -123,77 +123,77 @@ func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
if e := p.WriteString(ctx, s); e != nil {
return e
}
s, e1 = p.TypeIdToString(valueType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
if e := p.WriteString(ctx, s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
if e := p.WriteI64(ctx, int64(size)); e != nil {
return e
}
return p.OutputObjectBegin()
}
func (p *TJSONProtocol) WriteMapEnd() error {
func (p *TJSONProtocol) WriteMapEnd(ctx context.Context) error {
if e := p.OutputObjectEnd(); e != nil {
return e
}
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error {
func (p *TJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteListEnd() error {
func (p *TJSONProtocol) WriteListEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error {
func (p *TJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteSetEnd() error {
func (p *TJSONProtocol) WriteSetEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteBool(b bool) error {
func (p *TJSONProtocol) WriteBool(ctx context.Context, b bool) error {
if b {
return p.WriteI32(1)
return p.WriteI32(ctx, 1)
}
return p.WriteI32(0)
return p.WriteI32(ctx, 0)
}
func (p *TJSONProtocol) WriteByte(b int8) error {
return p.WriteI32(int32(b))
func (p *TJSONProtocol) WriteByte(ctx context.Context, b int8) error {
return p.WriteI32(ctx, int32(b))
}
func (p *TJSONProtocol) WriteI16(v int16) error {
return p.WriteI32(int32(v))
func (p *TJSONProtocol) WriteI16(ctx context.Context, v int16) error {
return p.WriteI32(ctx, int32(v))
}
func (p *TJSONProtocol) WriteI32(v int32) error {
func (p *TJSONProtocol) WriteI32(ctx context.Context, v int32) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteI64(v int64) error {
func (p *TJSONProtocol) WriteI64(ctx context.Context, v int64) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteDouble(v float64) error {
func (p *TJSONProtocol) WriteDouble(ctx context.Context, v float64) error {
return p.OutputF64(v)
}
func (p *TJSONProtocol) WriteString(v string) error {
func (p *TJSONProtocol) WriteString(ctx context.Context, v string) error {
return p.OutputString(v)
}
func (p *TJSONProtocol) WriteBinary(v []byte) error {
func (p *TJSONProtocol) WriteBinary(ctx context.Context, v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
@@ -219,12 +219,12 @@ func (p *TJSONProtocol) WriteBinary(v []byte) error {
}
// Reading methods.
func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
func (p *TJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
version, err := p.ReadI32()
version, err := p.ReadI32(ctx)
if err != nil {
return name, typeId, seqId, err
}
@@ -233,47 +233,47 @@ func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, se
return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
if name, err = p.ReadString(); err != nil {
if name, err = p.ReadString(ctx); err != nil {
return name, typeId, seqId, err
}
bTypeId, err := p.ReadByte()
bTypeId, err := p.ReadByte(ctx)
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
if seqId, err = p.ReadI32(); err != nil {
if seqId, err = p.ReadI32(ctx); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
func (p *TJSONProtocol) ReadMessageEnd() error {
func (p *TJSONProtocol) ReadMessageEnd(ctx context.Context) error {
err := p.ParseListEnd()
return err
}
func (p *TJSONProtocol) ReadStructBegin() (name string, err error) {
func (p *TJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
func (p *TJSONProtocol) ReadStructEnd() error {
func (p *TJSONProtocol) ReadStructEnd(ctx context.Context) error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
func (p *TJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) {
b, _ := p.reader.Peek(1)
if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
return "", STOP, -1, nil
}
fieldId, err := p.ReadI16()
fieldId, err := p.ReadI16(ctx)
if err != nil {
return "", STOP, fieldId, err
}
if _, err = p.ParseObjectStart(); err != nil {
return "", STOP, fieldId, err
}
sType, err := p.ReadString()
sType, err := p.ReadString(ctx)
if err != nil {
return "", STOP, fieldId, err
}
@@ -281,17 +281,17 @@ func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
return "", fType, fieldId, err
}
func (p *TJSONProtocol) ReadFieldEnd() error {
func (p *TJSONProtocol) ReadFieldEnd(ctx context.Context) error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
func (p *TJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
sKeyType, e := p.ReadString()
sKeyType, e := p.ReadString(ctx)
if e != nil {
return keyType, valueType, size, e
}
@@ -301,7 +301,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
}
// read valueType
sValueType, e := p.ReadString()
sValueType, e := p.ReadString(ctx)
if e != nil {
return keyType, valueType, size, e
}
@@ -311,7 +311,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
}
// read size
iSize, e := p.ReadI64()
iSize, e := p.ReadI64(ctx)
if e != nil {
return keyType, valueType, size, e
}
@@ -321,7 +321,7 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
return keyType, valueType, size, e
}
func (p *TJSONProtocol) ReadMapEnd() error {
func (p *TJSONProtocol) ReadMapEnd(ctx context.Context) error {
e := p.ParseObjectEnd()
if e != nil {
return e
@@ -329,53 +329,53 @@ func (p *TJSONProtocol) ReadMapEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
func (p *TJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadListEnd() error {
func (p *TJSONProtocol) ReadListEnd(ctx context.Context) error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
func (p *TJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadSetEnd() error {
func (p *TJSONProtocol) ReadSetEnd(ctx context.Context) error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadBool() (bool, error) {
value, err := p.ReadI32()
func (p *TJSONProtocol) ReadBool(ctx context.Context) (bool, error) {
value, err := p.ReadI32(ctx)
return (value != 0), err
}
func (p *TJSONProtocol) ReadByte() (int8, error) {
v, err := p.ReadI64()
func (p *TJSONProtocol) ReadByte(ctx context.Context) (int8, error) {
v, err := p.ReadI64(ctx)
return int8(v), err
}
func (p *TJSONProtocol) ReadI16() (int16, error) {
v, err := p.ReadI64()
func (p *TJSONProtocol) ReadI16(ctx context.Context) (int16, error) {
v, err := p.ReadI64(ctx)
return int16(v), err
}
func (p *TJSONProtocol) ReadI32() (int32, error) {
v, err := p.ReadI64()
func (p *TJSONProtocol) ReadI32(ctx context.Context) (int32, error) {
v, err := p.ReadI64(ctx)
return int32(v), err
}
func (p *TJSONProtocol) ReadI64() (int64, error) {
func (p *TJSONProtocol) ReadI64(ctx context.Context) (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
func (p *TJSONProtocol) ReadDouble() (float64, error) {
func (p *TJSONProtocol) ReadDouble(ctx context.Context) (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
func (p *TJSONProtocol) ReadString() (string, error) {
func (p *TJSONProtocol) ReadString(ctx context.Context) (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
@@ -405,7 +405,7 @@ func (p *TJSONProtocol) ReadString() (string, error) {
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
func (p *TJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
@@ -444,8 +444,8 @@ func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(err)
}
func (p *TJSONProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
func (p *TJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TJSONProtocol) Transport() TTransport {
@@ -460,10 +460,10 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
if e := p.OutputString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
@@ -473,7 +473,11 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error)
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
// We don't really use the ctx in ReadString implementation,
// so this is safe for now.
// We might want to add context to ParseElemListBegin if we start to use
// ctx in ReadString implementation in the future.
sElemType, err := p.ReadString(context.Background())
if err != nil {
return VOID, size, err
}
@@ -481,7 +485,7 @@ func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
nSize, _, err2 := p.ParseI64()
size = int(nSize)
return elemType, size, err2
}
@@ -490,7 +494,11 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error)
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
// We don't really use the ctx in ReadString implementation,
// so this is safe for now.
// We might want to add context to ParseElemListBegin if we start to use
// ctx in ReadString implementation in the future.
sElemType, err := p.ReadString(context.Background())
if err != nil {
return VOID, size, err
}
@@ -498,7 +506,7 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
nSize, _, err2 := p.ParseI64()
size = int(nSize)
return elemType, size, err2
}
@@ -579,3 +587,5 @@ func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
e := fmt.Errorf("Unknown type identifier: %s", fieldType)
return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}
var _ TConfigurationSetter = (*TJSONProtocol)(nil)
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"log"
"os"
"testing"
)
// Logger is a simple wrapper of a logging function.
//
// In reality the users might actually use different logging libraries, and they
// are not always compatible with each other.
//
// Logger is meant to be a simple common ground that it's easy to wrap whatever
// logging library they use into.
//
// See https://issues.apache.org/jira/browse/THRIFT-4985 for the design
// discussion behind it.
type Logger func(msg string)
// NopLogger is a Logger implementation that does nothing.
func NopLogger(msg string) {}
// StdLogger wraps stdlib log package into a Logger.
//
// If logger passed in is nil, it will fallback to use stderr and default flags.
func StdLogger(logger *log.Logger) Logger {
if logger == nil {
logger = log.New(os.Stderr, "", log.LstdFlags)
}
return func(msg string) {
logger.Print(msg)
}
}
// TestLogger is a Logger implementation can be used in test codes.
//
// It fails the test when being called.
func TestLogger(tb testing.TB) Logger {
return func(msg string) {
tb.Errorf("logger called with msg: %q", msg)
}
}
func fallbackLogger(logger Logger) Logger {
if logger == nil {
return StdLogger(nil)
}
return logger
}
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
// ProcessorMiddleware is a function that can be passed to WrapProcessor to wrap the
// TProcessorFunctions for that TProcessor.
//
// Middlewares are passed in the name of the function as set in the processor
// map of the TProcessor.
type ProcessorMiddleware func(name string, next TProcessorFunction) TProcessorFunction
// WrapProcessor takes an existing TProcessor and wraps each of its inner
// TProcessorFunctions with the middlewares passed in and returns it.
//
// Middlewares will be called in the order that they are defined:
//
// 1. Middlewares[0]
// 2. Middlewares[1]
// ...
// N. Middlewares[n]
func WrapProcessor(processor TProcessor, middlewares ...ProcessorMiddleware) TProcessor {
for name, processorFunc := range processor.ProcessorMap() {
wrapped := processorFunc
// Add middlewares in reverse so the first in the list is the outermost.
for i := len(middlewares) - 1; i >= 0; i-- {
wrapped = middlewares[i](name, wrapped)
}
processor.AddToProcessorMap(name, wrapped)
}
return processor
}
// WrappedTProcessorFunction is a convenience struct that implements the
// TProcessorFunction interface that can be used when implementing custom
// Middleware.
type WrappedTProcessorFunction struct {
// Wrapped is called by WrappedTProcessorFunction.Process and should be a
// "wrapped" call to a base TProcessorFunc.Process call.
Wrapped func(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}
// Process implements the TProcessorFunction interface using p.Wrapped.
func (p WrappedTProcessorFunction) Process(ctx context.Context, seqID int32, in, out TProtocol) (bool, TException) {
return p.Wrapped(ctx, seqID, in, out)
}
// verify that WrappedTProcessorFunction implements TProcessorFunction
var (
_ TProcessorFunction = WrappedTProcessorFunction{}
_ TProcessorFunction = (*WrappedTProcessorFunction)(nil)
)
// ClientMiddleware can be passed to WrapClient in order to wrap TClient calls
// with custom middleware.
type ClientMiddleware func(TClient) TClient
// WrappedTClient is a convenience struct that implements the TClient interface
// using inner Wrapped function.
//
// This is provided to aid in developing ClientMiddleware.
type WrappedTClient struct {
Wrapped func(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error)
}
// Call implements the TClient interface by calling and returning c.Wrapped.
func (c WrappedTClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
return c.Wrapped(ctx, method, args, result)
}
// verify that WrappedTClient implements TClient
var (
_ TClient = WrappedTClient{}
_ TClient = (*WrappedTClient)(nil)
)
// WrapClient wraps the given TClient in the given middlewares.
//
// Middlewares will be called in the order that they are defined:
//
// 1. Middlewares[0]
// 2. Middlewares[1]
// ...
// N. Middlewares[n]
func WrapClient(client TClient, middlewares ...ClientMiddleware) TClient {
// Add middlewares in reverse so the first in the list is the outermost.
for i := len(middlewares) - 1; i >= 0; i-- {
client = middlewares[i](client)
}
return client
}
@@ -68,11 +68,11 @@ func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplex
}
}
func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
func (t *TMultiplexedProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error {
if typeId == CALL || typeId == ONEWAY {
return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
return t.TProtocol.WriteMessageBegin(ctx, t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
} else {
return t.TProtocol.WriteMessageBegin(name, typeId, seqid)
return t.TProtocol.WriteMessageBegin(ctx, name, typeId, seqid)
}
}
@@ -117,6 +117,67 @@ func NewTMultiplexedProcessor() *TMultiplexedProcessor {
}
}
// ProcessorMap returns a mapping of "{ProcessorName}{MULTIPLEXED_SEPARATOR}{FunctionName}"
// to TProcessorFunction for any registered processors. If there is also a
// DefaultProcessor, the keys for the methods on that processor will simply be
// "{FunctionName}". If the TMultiplexedProcessor has both a DefaultProcessor and
// other registered processors, then the keys will be a mix of both formats.
//
// The implementation differs with other TProcessors in that the map returned is
// a new map, while most TProcessors just return their internal mapping directly.
// This means that edits to the map returned by this implementation of ProcessorMap
// will not affect the underlying mapping within the TMultiplexedProcessor.
func (t *TMultiplexedProcessor) ProcessorMap() map[string]TProcessorFunction {
processorFuncMap := make(map[string]TProcessorFunction)
for name, processor := range t.serviceProcessorMap {
for method, processorFunc := range processor.ProcessorMap() {
processorFuncName := name + MULTIPLEXED_SEPARATOR + method
processorFuncMap[processorFuncName] = processorFunc
}
}
if t.DefaultProcessor != nil {
for method, processorFunc := range t.DefaultProcessor.ProcessorMap() {
processorFuncMap[method] = processorFunc
}
}
return processorFuncMap
}
// AddToProcessorMap updates the underlying TProcessor ProccessorMaps depending on
// the format of "name".
//
// If "name" is in the format "{ProcessorName}{MULTIPLEXED_SEPARATOR}{FunctionName}",
// then it sets the given TProcessorFunction on the inner TProcessor with the
// ProcessorName component using the FunctionName component.
//
// If "name" is just in the format "{FunctionName}", that is to say there is no
// MULTIPLEXED_SEPARATOR, and the TMultiplexedProcessor has a DefaultProcessor
// configured, then it will set the given TProcessorFunction on the DefaultProcessor
// using the given name.
//
// If there is not a TProcessor available for the given name, then this function
// does nothing. This can happen when there is no TProcessor registered for
// the given ProcessorName or if all that is given is the FunctionName and there
// is no DefaultProcessor set.
func (t *TMultiplexedProcessor) AddToProcessorMap(name string, processorFunc TProcessorFunction) {
components := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
if len(components) != 2 {
if t.DefaultProcessor != nil && len(components) == 1 {
t.DefaultProcessor.AddToProcessorMap(components[0], processorFunc)
}
return
}
processorName := components[0]
funcName := components[1]
if processor, ok := t.serviceProcessorMap[processorName]; ok {
processor.AddToProcessorMap(funcName, processorFunc)
}
}
// verify that TMultiplexedProcessor implements TProcessor
var _ TProcessor = (*TMultiplexedProcessor)(nil)
func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) {
t.DefaultProcessor = processor
}
@@ -129,12 +190,12 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProces
}
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin()
name, typeId, seqid, err := in.ReadMessageBegin(ctx)
if err != nil {
return false, err
return false, NewTProtocolException(err)
}
if typeId != CALL && typeId != ONEWAY {
return false, fmt.Errorf("Unexpected message type %v", typeId)
return false, NewTProtocolException(fmt.Errorf("Unexpected message type %v", typeId))
}
//extract the service name
v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
@@ -143,11 +204,17 @@ func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol)
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(ctx, smb, out)
}
return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
return false, NewTProtocolException(fmt.Errorf(
"Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?",
name,
))
}
actualProcessor, ok := t.serviceProcessorMap[v[0]]
if !ok {
return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
return false, NewTProtocolException(fmt.Errorf(
"Service name not found: %s. Did you forget to call registerProcessor()?",
v[0],
))
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(ctx, smb, out)
@@ -165,6 +232,6 @@ func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageTy
return &storedMessageProtocol{protocol, name, typeId, seqid}
}
func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
func (s *storedMessageProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error) {
return s.name, s.typeId, s.seqid, nil
}
@@ -69,14 +69,14 @@ func NewNumericFromDouble(dValue float64) Numeric {
func NewNumericFromI64(iValue int64) Numeric {
dValue := float64(iValue)
sValue := string(iValue)
sValue := strconv.FormatInt(iValue, 10)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI32(iValue int32) Numeric {
dValue := float64(iValue)
sValue := string(iValue)
sValue := strconv.FormatInt(int64(iValue), 10)
isNil := false
return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil}
}
@@ -25,6 +25,16 @@ import "context"
// writes to some output stream.
type TProcessor interface {
Process(ctx context.Context, in, out TProtocol) (bool, TException)
// ProcessorMap returns a map of thrift method names to TProcessorFunctions.
ProcessorMap() map[string]TProcessorFunction
// AddToProcessorMap adds the given TProcessorFunction to the internal
// processor map at the given key.
//
// If one is already set at the given key, it will be replaced with the new
// TProcessorFunction.
AddToProcessorMap(string, TProcessorFunction)
}
type TProcessorFunction interface {
@@ -31,50 +31,50 @@ const (
)
type TProtocol interface {
WriteMessageBegin(name string, typeId TMessageType, seqid int32) error
WriteMessageEnd() error
WriteStructBegin(name string) error
WriteStructEnd() error
WriteFieldBegin(name string, typeId TType, id int16) error
WriteFieldEnd() error
WriteFieldStop() error
WriteMapBegin(keyType TType, valueType TType, size int) error
WriteMapEnd() error
WriteListBegin(elemType TType, size int) error
WriteListEnd() error
WriteSetBegin(elemType TType, size int) error
WriteSetEnd() error
WriteBool(value bool) error
WriteByte(value int8) error
WriteI16(value int16) error
WriteI32(value int32) error
WriteI64(value int64) error
WriteDouble(value float64) error
WriteString(value string) error
WriteBinary(value []byte) error
WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqid int32) error
WriteMessageEnd(ctx context.Context) error
WriteStructBegin(ctx context.Context, name string) error
WriteStructEnd(ctx context.Context) error
WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error
WriteFieldEnd(ctx context.Context) error
WriteFieldStop(ctx context.Context) error
WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error
WriteMapEnd(ctx context.Context) error
WriteListBegin(ctx context.Context, elemType TType, size int) error
WriteListEnd(ctx context.Context) error
WriteSetBegin(ctx context.Context, elemType TType, size int) error
WriteSetEnd(ctx context.Context) error
WriteBool(ctx context.Context, value bool) error
WriteByte(ctx context.Context, value int8) error
WriteI16(ctx context.Context, value int16) error
WriteI32(ctx context.Context, value int32) error
WriteI64(ctx context.Context, value int64) error
WriteDouble(ctx context.Context, value float64) error
WriteString(ctx context.Context, value string) error
WriteBinary(ctx context.Context, value []byte) error
ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error)
ReadMessageEnd() error
ReadStructBegin() (name string, err error)
ReadStructEnd() error
ReadFieldBegin() (name string, typeId TType, id int16, err error)
ReadFieldEnd() error
ReadMapBegin() (keyType TType, valueType TType, size int, err error)
ReadMapEnd() error
ReadListBegin() (elemType TType, size int, err error)
ReadListEnd() error
ReadSetBegin() (elemType TType, size int, err error)
ReadSetEnd() error
ReadBool() (value bool, err error)
ReadByte() (value int8, err error)
ReadI16() (value int16, err error)
ReadI32() (value int32, err error)
ReadI64() (value int64, err error)
ReadDouble() (value float64, err error)
ReadString() (value string, err error)
ReadBinary() (value []byte, err error)
ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqid int32, err error)
ReadMessageEnd(ctx context.Context) error
ReadStructBegin(ctx context.Context) (name string, err error)
ReadStructEnd(ctx context.Context) error
ReadFieldBegin(ctx context.Context) (name string, typeId TType, id int16, err error)
ReadFieldEnd(ctx context.Context) error
ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error)
ReadMapEnd(ctx context.Context) error
ReadListBegin(ctx context.Context) (elemType TType, size int, err error)
ReadListEnd(ctx context.Context) error
ReadSetBegin(ctx context.Context) (elemType TType, size int, err error)
ReadSetEnd(ctx context.Context) error
ReadBool(ctx context.Context) (value bool, err error)
ReadByte(ctx context.Context) (value int8, err error)
ReadI16(ctx context.Context) (value int16, err error)
ReadI32(ctx context.Context) (value int32, err error)
ReadI64(ctx context.Context) (value int64, err error)
ReadDouble(ctx context.Context) (value float64, err error)
ReadString(ctx context.Context) (value string, err error)
ReadBinary(ctx context.Context) (value []byte, err error)
Skip(fieldType TType) (err error)
Skip(ctx context.Context, fieldType TType) (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
@@ -84,12 +84,12 @@ type TProtocol interface {
const DEFAULT_RECURSION_DEPTH = 64
// Skips over the next data element from the provided input TProtocol object.
func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) {
return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH)
func SkipDefaultDepth(ctx context.Context, prot TProtocol, typeId TType) (err error) {
return Skip(ctx, prot, typeId, DEFAULT_RECURSION_DEPTH)
}
// Skips over the next data element from the provided input TProtocol object.
func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
func Skip(ctx context.Context, self TProtocol, fieldType TType, maxDepth int) (err error) {
if maxDepth <= 0 {
return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded"))
@@ -97,79 +97,79 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
switch fieldType {
case BOOL:
_, err = self.ReadBool()
_, err = self.ReadBool(ctx)
return
case BYTE:
_, err = self.ReadByte()
_, err = self.ReadByte(ctx)
return
case I16:
_, err = self.ReadI16()
_, err = self.ReadI16(ctx)
return
case I32:
_, err = self.ReadI32()
_, err = self.ReadI32(ctx)
return
case I64:
_, err = self.ReadI64()
_, err = self.ReadI64(ctx)
return
case DOUBLE:
_, err = self.ReadDouble()
_, err = self.ReadDouble(ctx)
return
case STRING:
_, err = self.ReadString()
_, err = self.ReadString(ctx)
return
case STRUCT:
if _, err = self.ReadStructBegin(); err != nil {
if _, err = self.ReadStructBegin(ctx); err != nil {
return err
}
for {
_, typeId, _, _ := self.ReadFieldBegin()
_, typeId, _, _ := self.ReadFieldBegin(ctx)
if typeId == STOP {
break
}
err := Skip(self, typeId, maxDepth-1)
err := Skip(ctx, self, typeId, maxDepth-1)
if err != nil {
return err
}
self.ReadFieldEnd()
self.ReadFieldEnd(ctx)
}
return self.ReadStructEnd()
return self.ReadStructEnd(ctx)
case MAP:
keyType, valueType, size, err := self.ReadMapBegin()
keyType, valueType, size, err := self.ReadMapBegin(ctx)
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, keyType, maxDepth-1)
err := Skip(ctx, self, keyType, maxDepth-1)
if err != nil {
return err
}
self.Skip(valueType)
self.Skip(ctx, valueType)
}
return self.ReadMapEnd()
return self.ReadMapEnd(ctx)
case SET:
elemType, size, err := self.ReadSetBegin()
elemType, size, err := self.ReadSetBegin(ctx)
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
err := Skip(ctx, self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadSetEnd()
return self.ReadSetEnd(ctx)
case LIST:
elemType, size, err := self.ReadListBegin()
elemType, size, err := self.ReadListBegin(ctx)
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
err := Skip(ctx, self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadListEnd()
return self.ReadListEnd(ctx)
default:
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType)))
}
@@ -21,6 +21,7 @@ package thrift
import (
"encoding/base64"
"errors"
)
// Thrift Protocol exception
@@ -40,8 +41,15 @@ const (
)
type tProtocolException struct {
typeId int
message string
typeId int
err error
msg string
}
var _ TProtocolException = (*tProtocolException)(nil)
func (tProtocolException) TExceptionType() TExceptionType {
return TExceptionTypeProtocol
}
func (p *tProtocolException) TypeId() int {
@@ -49,29 +57,48 @@ func (p *tProtocolException) TypeId() int {
}
func (p *tProtocolException) String() string {
return p.message
return p.msg
}
func (p *tProtocolException) Error() string {
return p.message
return p.msg
}
func (p *tProtocolException) Unwrap() error {
return p.err
}
func NewTProtocolException(err error) TProtocolException {
if err == nil {
return nil
}
if e, ok := err.(TProtocolException); ok {
return e
}
if _, ok := err.(base64.CorruptInputError); ok {
return &tProtocolException{INVALID_DATA, err.Error()}
if errors.As(err, new(base64.CorruptInputError)) {
return NewTProtocolExceptionWithType(INVALID_DATA, err)
}
return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()}
return NewTProtocolExceptionWithType(UNKNOWN_PROTOCOL_EXCEPTION, err)
}
func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
if err == nil {
return nil
}
return &tProtocolException{errType, err.Error()}
return &tProtocolException{
typeId: errType,
err: err,
msg: err.Error(),
}
}
func prependTProtocolException(prepend string, err TProtocolException) TProtocolException {
return &tProtocolException{
typeId: err.TypeId(),
err: err,
msg: prepend + err.Error(),
}
}
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
)
// See https://godoc.org/context#WithValue on why do we need the unexported typedefs.
type responseHelperKey struct{}
// TResponseHelper defines a object with a set of helper functions that can be
// retrieved from the context object passed into server handler functions.
//
// Use GetResponseHelper to retrieve the injected TResponseHelper implementation
// from the context object.
//
// The zero value of TResponseHelper is valid with all helper functions being
// no-op.
type TResponseHelper struct {
// THeader related functions
*THeaderResponseHelper
}
// THeaderResponseHelper defines THeader related TResponseHelper functions.
//
// The zero value of *THeaderResponseHelper is valid with all helper functions
// being no-op.
type THeaderResponseHelper struct {
proto *THeaderProtocol
}
// NewTHeaderResponseHelper creates a new THeaderResponseHelper from the
// underlying TProtocol.
func NewTHeaderResponseHelper(proto TProtocol) *THeaderResponseHelper {
if hp, ok := proto.(*THeaderProtocol); ok {
return &THeaderResponseHelper{
proto: hp,
}
}
return nil
}
// SetHeader sets a response header.
//
// It's no-op if the underlying protocol/transport does not support THeader.
func (h *THeaderResponseHelper) SetHeader(key, value string) {
if h != nil && h.proto != nil {
h.proto.SetWriteHeader(key, value)
}
}
// ClearHeaders clears all the response headers previously set.
//
// It's no-op if the underlying protocol/transport does not support THeader.
func (h *THeaderResponseHelper) ClearHeaders() {
if h != nil && h.proto != nil {
h.proto.ClearWriteHeaders()
}
}
// GetResponseHelper retrieves the TResponseHelper implementation injected into
// the context object.
//
// If no helper was found in the context object, a nop helper with ok == false
// will be returned.
func GetResponseHelper(ctx context.Context) (helper TResponseHelper, ok bool) {
if v := ctx.Value(responseHelperKey{}); v != nil {
helper, ok = v.(TResponseHelper)
}
return
}
// SetResponseHelper injects TResponseHelper into the context object.
func SetResponseHelper(ctx context.Context, helper TResponseHelper) context.Context {
return context.WithValue(ctx, responseHelperKey{}, helper)
}
@@ -19,7 +19,10 @@
package thrift
import "io"
import (
"errors"
"io"
)
type RichTransport struct {
TTransport
@@ -49,7 +52,7 @@ func (r *RichTransport) RemainingBytes() (num_bytes uint64) {
func readByte(r io.Reader) (c byte, err error) {
v := [1]byte{0}
n, err := r.Read(v[0:1])
if n > 0 && (err == nil || err == io.EOF) {
if n > 0 && (err == nil || errors.Is(err, io.EOF)) {
return v[0], nil
}
if n > 0 && err != nil {
@@ -21,6 +21,7 @@ package thrift
import (
"context"
"sync"
)
type TSerializer struct {
@@ -29,23 +30,24 @@ type TSerializer struct {
}
type TStruct interface {
Write(p TProtocol) error
Read(p TProtocol) error
Write(ctx context.Context, p TProtocol) error
Read(ctx context.Context, p TProtocol) error
}
func NewTSerializer() *TSerializer {
transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
protocol := NewTBinaryProtocolTransport(transport)
return &TSerializer{
transport,
protocol}
Transport: transport,
Protocol: protocol,
}
}
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
if err = msg.Write(ctx, t.Protocol); err != nil {
return
}
@@ -62,7 +64,7 @@ func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, e
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
if err = msg.Write(ctx, t.Protocol); err != nil {
return
}
@@ -77,3 +79,58 @@ func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err err
b = append(b, t.Transport.Bytes()...)
return
}
// TSerializerPool is the thread-safe version of TSerializer, it uses resource
// pool of TSerializer under the hood.
//
// It must be initialized with either NewTSerializerPool or
// NewTSerializerPoolSizeFactory.
type TSerializerPool struct {
pool sync.Pool
}
// NewTSerializerPool creates a new TSerializerPool.
//
// NewTSerializer can be used as the arg here.
func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
return &TSerializerPool{
pool: sync.Pool{
New: func() interface{} {
return f()
},
},
}
}
// NewTSerializerPoolSizeFactory creates a new TSerializerPool with the given
// size and protocol factory.
//
// Note that the size is not the limit. The TMemoryBuffer underneath can grow
// larger than that. It just dictates the initial size.
func NewTSerializerPoolSizeFactory(size int, factory TProtocolFactory) *TSerializerPool {
return &TSerializerPool{
pool: sync.Pool{
New: func() interface{} {
transport := NewTMemoryBufferLen(size)
protocol := factory.GetProtocol(transport)
return &TSerializer{
Transport: transport,
Protocol: protocol,
}
},
},
}
}
func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) (string, error) {
s := t.pool.Get().(*TSerializer)
defer t.pool.Put(s)
return s.WriteString(ctx, msg)
}
func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, error) {
s := t.pool.Get().(*TSerializer)
defer t.pool.Put(s)
return s.Write(ctx, msg)
}
@@ -25,6 +25,7 @@ import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math"
@@ -34,12 +35,13 @@ import (
type _ParseContext int
const (
_CONTEXT_IN_TOPLEVEL _ParseContext = 1
_CONTEXT_IN_LIST_FIRST _ParseContext = 2
_CONTEXT_IN_LIST _ParseContext = 3
_CONTEXT_IN_OBJECT_FIRST _ParseContext = 4
_CONTEXT_IN_OBJECT_NEXT_KEY _ParseContext = 5
_CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6
_CONTEXT_INVALID _ParseContext = iota
_CONTEXT_IN_TOPLEVEL // 1
_CONTEXT_IN_LIST_FIRST // 2
_CONTEXT_IN_LIST // 3
_CONTEXT_IN_OBJECT_FIRST // 4
_CONTEXT_IN_OBJECT_NEXT_KEY // 5
_CONTEXT_IN_OBJECT_NEXT_VALUE // 6
)
func (p _ParseContext) String() string {
@@ -60,6 +62,32 @@ func (p _ParseContext) String() string {
return "UNKNOWN-PARSE-CONTEXT"
}
type jsonContextStack []_ParseContext
func (s *jsonContextStack) push(v _ParseContext) {
*s = append(*s, v)
}
func (s jsonContextStack) peek() (v _ParseContext, ok bool) {
l := len(s)
if l <= 0 {
return
}
return s[l-1], true
}
func (s *jsonContextStack) pop() (v _ParseContext, ok bool) {
l := len(*s)
if l <= 0 {
return
}
v = (*s)[l-1]
*s = (*s)[0 : l-1]
return v, true
}
var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Unexpected empty json protocol context stack"))
// Simple JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
@@ -69,8 +97,8 @@ func (p _ParseContext) String() string {
type TSimpleJSONProtocol struct {
trans TTransport
parseContextStack []int
dumpContext []int
parseContextStack jsonContextStack
dumpContext jsonContextStack
writer *bufio.Writer
reader *bufio.Reader
@@ -82,8 +110,8 @@ func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol {
writer: bufio.NewWriter(t),
reader: bufio.NewReader(t),
}
v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
return v
}
@@ -156,114 +184,113 @@ func mismatch(expected, actual string) error {
return fmt.Errorf("Expected '%s' but found '%s' while parsing JSON.", expected, actual)
}
func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
func (p *TSimpleJSONProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteString(name); e != nil {
if e := p.WriteString(ctx, name); e != nil {
return e
}
if e := p.WriteByte(int8(typeId)); e != nil {
if e := p.WriteByte(ctx, int8(typeId)); e != nil {
return e
}
if e := p.WriteI32(seqId); e != nil {
if e := p.WriteI32(ctx, seqId); e != nil {
return e
}
return nil
}
func (p *TSimpleJSONProtocol) WriteMessageEnd() error {
func (p *TSimpleJSONProtocol) WriteMessageEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TSimpleJSONProtocol) WriteStructBegin(name string) error {
func (p *TSimpleJSONProtocol) WriteStructBegin(ctx context.Context, name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
func (p *TSimpleJSONProtocol) WriteStructEnd() error {
func (p *TSimpleJSONProtocol) WriteStructEnd(ctx context.Context) error {
return p.OutputObjectEnd()
}
func (p *TSimpleJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if e := p.WriteString(name); e != nil {
func (p *TSimpleJSONProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
if e := p.WriteString(ctx, name); e != nil {
return e
}
return nil
}
func (p *TSimpleJSONProtocol) WriteFieldEnd() error {
//return p.OutputListEnd()
func (p *TSimpleJSONProtocol) WriteFieldEnd(ctx context.Context) error {
return nil
}
func (p *TSimpleJSONProtocol) WriteFieldStop() error { return nil }
func (p *TSimpleJSONProtocol) WriteFieldStop(ctx context.Context) error { return nil }
func (p *TSimpleJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
func (p *TSimpleJSONProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteByte(int8(keyType)); e != nil {
if e := p.WriteByte(ctx, int8(keyType)); e != nil {
return e
}
if e := p.WriteByte(int8(valueType)); e != nil {
if e := p.WriteByte(ctx, int8(valueType)); e != nil {
return e
}
return p.WriteI32(int32(size))
return p.WriteI32(ctx, int32(size))
}
func (p *TSimpleJSONProtocol) WriteMapEnd() error {
func (p *TSimpleJSONProtocol) WriteMapEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TSimpleJSONProtocol) WriteListBegin(elemType TType, size int) error {
func (p *TSimpleJSONProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TSimpleJSONProtocol) WriteListEnd() error {
func (p *TSimpleJSONProtocol) WriteListEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TSimpleJSONProtocol) WriteSetBegin(elemType TType, size int) error {
func (p *TSimpleJSONProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TSimpleJSONProtocol) WriteSetEnd() error {
func (p *TSimpleJSONProtocol) WriteSetEnd(ctx context.Context) error {
return p.OutputListEnd()
}
func (p *TSimpleJSONProtocol) WriteBool(b bool) error {
func (p *TSimpleJSONProtocol) WriteBool(ctx context.Context, b bool) error {
return p.OutputBool(b)
}
func (p *TSimpleJSONProtocol) WriteByte(b int8) error {
return p.WriteI32(int32(b))
func (p *TSimpleJSONProtocol) WriteByte(ctx context.Context, b int8) error {
return p.WriteI32(ctx, int32(b))
}
func (p *TSimpleJSONProtocol) WriteI16(v int16) error {
return p.WriteI32(int32(v))
func (p *TSimpleJSONProtocol) WriteI16(ctx context.Context, v int16) error {
return p.WriteI32(ctx, int32(v))
}
func (p *TSimpleJSONProtocol) WriteI32(v int32) error {
func (p *TSimpleJSONProtocol) WriteI32(ctx context.Context, v int32) error {
return p.OutputI64(int64(v))
}
func (p *TSimpleJSONProtocol) WriteI64(v int64) error {
func (p *TSimpleJSONProtocol) WriteI64(ctx context.Context, v int64) error {
return p.OutputI64(int64(v))
}
func (p *TSimpleJSONProtocol) WriteDouble(v float64) error {
func (p *TSimpleJSONProtocol) WriteDouble(ctx context.Context, v float64) error {
return p.OutputF64(v)
}
func (p *TSimpleJSONProtocol) WriteString(v string) error {
func (p *TSimpleJSONProtocol) WriteString(ctx context.Context, v string) error {
return p.OutputString(v)
}
func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error {
func (p *TSimpleJSONProtocol) WriteBinary(ctx context.Context, v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
@@ -289,39 +316,39 @@ func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error {
}
// Reading methods.
func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
func (p *TSimpleJSONProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
if name, err = p.ReadString(); err != nil {
if name, err = p.ReadString(ctx); err != nil {
return name, typeId, seqId, err
}
bTypeId, err := p.ReadByte()
bTypeId, err := p.ReadByte(ctx)
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
if seqId, err = p.ReadI32(); err != nil {
if seqId, err = p.ReadI32(ctx); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
func (p *TSimpleJSONProtocol) ReadMessageEnd() error {
func (p *TSimpleJSONProtocol) ReadMessageEnd(ctx context.Context) error {
return p.ParseListEnd()
}
func (p *TSimpleJSONProtocol) ReadStructBegin() (name string, err error) {
func (p *TSimpleJSONProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
func (p *TSimpleJSONProtocol) ReadStructEnd() error {
func (p *TSimpleJSONProtocol) ReadStructEnd(ctx context.Context) error {
return p.ParseObjectEnd()
}
func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
func (p *TSimpleJSONProtocol) ReadFieldBegin(ctx context.Context) (string, TType, int16, error) {
if err := p.ParsePreValue(); err != nil {
return "", STOP, 0, err
}
@@ -340,21 +367,6 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
return name, STOP, 0, err
}
return name, STOP, -1, p.ParsePostValue()
/*
if err = p.ParsePostValue(); err != nil {
return name, STOP, 0, err
}
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, STOP, 0, err
}
bType, err := p.ReadByte()
thetype := TType(bType)
if err != nil {
return name, thetype, 0, err
}
id, err := p.ReadI16()
return name, thetype, id, err
*/
}
e := fmt.Errorf("Expected \"}\" or '\"', but found: '%s'", string(b))
return "", STOP, 0, NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -362,57 +374,56 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
return "", STOP, 0, NewTProtocolException(io.EOF)
}
func (p *TSimpleJSONProtocol) ReadFieldEnd() error {
func (p *TSimpleJSONProtocol) ReadFieldEnd(ctx context.Context) error {
return nil
//return p.ParseListEnd()
}
func (p *TSimpleJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
func (p *TSimpleJSONProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
bKeyType, e := p.ReadByte()
bKeyType, e := p.ReadByte(ctx)
keyType = TType(bKeyType)
if e != nil {
return keyType, valueType, size, e
}
// read valueType
bValueType, e := p.ReadByte()
bValueType, e := p.ReadByte(ctx)
valueType = TType(bValueType)
if e != nil {
return keyType, valueType, size, e
}
// read size
iSize, err := p.ReadI64()
iSize, err := p.ReadI64(ctx)
size = int(iSize)
return keyType, valueType, size, err
}
func (p *TSimpleJSONProtocol) ReadMapEnd() error {
func (p *TSimpleJSONProtocol) ReadMapEnd(ctx context.Context) error {
return p.ParseListEnd()
}
func (p *TSimpleJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
func (p *TSimpleJSONProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TSimpleJSONProtocol) ReadListEnd() error {
func (p *TSimpleJSONProtocol) ReadListEnd(ctx context.Context) error {
return p.ParseListEnd()
}
func (p *TSimpleJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
func (p *TSimpleJSONProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TSimpleJSONProtocol) ReadSetEnd() error {
func (p *TSimpleJSONProtocol) ReadSetEnd(ctx context.Context) error {
return p.ParseListEnd()
}
func (p *TSimpleJSONProtocol) ReadBool() (bool, error) {
func (p *TSimpleJSONProtocol) ReadBool(ctx context.Context) (bool, error) {
var value bool
if err := p.ParsePreValue(); err != nil {
@@ -467,32 +478,32 @@ func (p *TSimpleJSONProtocol) ReadBool() (bool, error) {
return value, p.ParsePostValue()
}
func (p *TSimpleJSONProtocol) ReadByte() (int8, error) {
v, err := p.ReadI64()
func (p *TSimpleJSONProtocol) ReadByte(ctx context.Context) (int8, error) {
v, err := p.ReadI64(ctx)
return int8(v), err
}
func (p *TSimpleJSONProtocol) ReadI16() (int16, error) {
v, err := p.ReadI64()
func (p *TSimpleJSONProtocol) ReadI16(ctx context.Context) (int16, error) {
v, err := p.ReadI64(ctx)
return int16(v), err
}
func (p *TSimpleJSONProtocol) ReadI32() (int32, error) {
v, err := p.ReadI64()
func (p *TSimpleJSONProtocol) ReadI32(ctx context.Context) (int32, error) {
v, err := p.ReadI64(ctx)
return int32(v), err
}
func (p *TSimpleJSONProtocol) ReadI64() (int64, error) {
func (p *TSimpleJSONProtocol) ReadI64(ctx context.Context) (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
func (p *TSimpleJSONProtocol) ReadDouble() (float64, error) {
func (p *TSimpleJSONProtocol) ReadDouble(ctx context.Context) (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
func (p *TSimpleJSONProtocol) ReadString() (string, error) {
func (p *TSimpleJSONProtocol) ReadString(ctx context.Context) (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
@@ -522,7 +533,7 @@ func (p *TSimpleJSONProtocol) ReadString() (string, error) {
return v, p.ParsePostValue()
}
func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) {
func (p *TSimpleJSONProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
@@ -557,8 +568,8 @@ func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.writer.Flush())
}
func (p *TSimpleJSONProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
func (p *TSimpleJSONProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
return SkipDefaultDepth(ctx, p, fieldType)
}
func (p *TSimpleJSONProtocol) Transport() TTransport {
@@ -566,41 +577,41 @@ func (p *TSimpleJSONProtocol) Transport() TTransport {
}
func (p *TSimpleJSONProtocol) OutputPreValue() error {
cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
cxt, ok := p.dumpContext.peek()
if !ok {
return errEmptyJSONContextStack
}
switch cxt {
case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY:
if _, e := p.write(JSON_COMMA); e != nil {
return NewTProtocolException(e)
}
break
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
if _, e := p.write(JSON_COLON); e != nil {
return NewTProtocolException(e)
}
break
}
return nil
}
func (p *TSimpleJSONProtocol) OutputPostValue() error {
cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
cxt, ok := p.dumpContext.peek()
if !ok {
return errEmptyJSONContextStack
}
switch cxt {
case _CONTEXT_IN_LIST_FIRST:
p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST))
break
p.dumpContext.pop()
p.dumpContext.push(_CONTEXT_IN_LIST)
case _CONTEXT_IN_OBJECT_FIRST:
p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
break
p.dumpContext.pop()
p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_KEY:
p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
break
p.dumpContext.pop()
p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
break
p.dumpContext.pop()
p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
}
return nil
}
@@ -615,10 +626,13 @@ func (p *TSimpleJSONProtocol) OutputBool(value bool) error {
} else {
v = string(JSON_FALSE)
}
switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
cxt, ok := p.dumpContext.peek()
if !ok {
return errEmptyJSONContextStack
}
switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = jsonQuote(v)
default:
}
if e := p.OutputStringData(v); e != nil {
return e
@@ -648,11 +662,14 @@ func (p *TSimpleJSONProtocol) OutputF64(value float64) error {
} else if math.IsInf(value, -1) {
v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE)
} else {
cxt, ok := p.dumpContext.peek()
if !ok {
return errEmptyJSONContextStack
}
v = strconv.FormatFloat(value, 'g', -1, 64)
switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = string(JSON_QUOTE) + v + string(JSON_QUOTE)
default:
}
}
if e := p.OutputStringData(v); e != nil {
@@ -665,11 +682,14 @@ func (p *TSimpleJSONProtocol) OutputI64(value int64) error {
if e := p.OutputPreValue(); e != nil {
return e
}
cxt, ok := p.dumpContext.peek()
if !ok {
return errEmptyJSONContextStack
}
v := strconv.FormatInt(value, 10)
switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
switch cxt {
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
v = jsonQuote(v)
default:
}
if e := p.OutputStringData(v); e != nil {
return e
@@ -699,7 +719,7 @@ func (p *TSimpleJSONProtocol) OutputObjectBegin() error {
if _, e := p.write(JSON_LBRACE); e != nil {
return NewTProtocolException(e)
}
p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST))
p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST)
return nil
}
@@ -707,7 +727,10 @@ func (p *TSimpleJSONProtocol) OutputObjectEnd() error {
if _, e := p.write(JSON_RBRACE); e != nil {
return NewTProtocolException(e)
}
p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
_, ok := p.dumpContext.pop()
if !ok {
return errEmptyJSONContextStack
}
if e := p.OutputPostValue(); e != nil {
return e
}
@@ -721,7 +744,7 @@ func (p *TSimpleJSONProtocol) OutputListBegin() error {
if _, e := p.write(JSON_LBRACKET); e != nil {
return NewTProtocolException(e)
}
p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST))
p.dumpContext.push(_CONTEXT_IN_LIST_FIRST)
return nil
}
@@ -729,7 +752,10 @@ func (p *TSimpleJSONProtocol) OutputListEnd() error {
if _, e := p.write(JSON_RBRACKET); e != nil {
return NewTProtocolException(e)
}
p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
_, ok := p.dumpContext.pop()
if !ok {
return errEmptyJSONContextStack
}
if e := p.OutputPostValue(); e != nil {
return e
}
@@ -740,10 +766,10 @@ func (p *TSimpleJSONProtocol) OutputElemListBegin(elemType TType, size int) erro
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteByte(int8(elemType)); e != nil {
if e := p.OutputI64(int64(elemType)); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
@@ -753,7 +779,10 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
if e := p.readNonSignificantWhitespace(); e != nil {
return NewTProtocolException(e)
}
cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
cxt, ok := p.parseContextStack.peek()
if !ok {
return errEmptyJSONContextStack
}
b, _ := p.reader.Peek(1)
switch cxt {
case _CONTEXT_IN_LIST:
@@ -772,7 +801,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
break
case _CONTEXT_IN_OBJECT_NEXT_KEY:
if len(b) > 0 {
switch b[0] {
@@ -789,7 +817,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
break
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
if len(b) > 0 {
switch b[0] {
@@ -804,7 +831,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
return NewTProtocolExceptionWithType(INVALID_DATA, e)
}
}
break
}
return nil
}
@@ -813,20 +839,20 @@ func (p *TSimpleJSONProtocol) ParsePostValue() error {
if e := p.readNonSignificantWhitespace(); e != nil {
return NewTProtocolException(e)
}
cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
cxt, ok := p.parseContextStack.peek()
if !ok {
return errEmptyJSONContextStack
}
switch cxt {
case _CONTEXT_IN_LIST_FIRST:
p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST))
break
p.parseContextStack.pop()
p.parseContextStack.push(_CONTEXT_IN_LIST)
case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
break
p.parseContextStack.pop()
p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
case _CONTEXT_IN_OBJECT_NEXT_VALUE:
p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
break
p.parseContextStack.pop()
p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
}
return nil
}
@@ -979,7 +1005,7 @@ func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) {
}
if len(b) > 0 && b[0] == JSON_LBRACE[0] {
p.reader.ReadByte()
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST)
return false, nil
} else if p.safePeekContains(JSON_NULL) {
return true, nil
@@ -992,7 +1018,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
if isNull, err := p.readIfNull(); isNull || err != nil {
return err
}
cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
cxt, _ := p.parseContextStack.peek()
if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) {
e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -1010,7 +1036,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
break
}
}
p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
p.parseContextStack.pop()
return p.ParsePostValue()
}
@@ -1024,7 +1050,7 @@ func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) {
return false, err
}
if len(b) >= 1 && b[0] == JSON_LBRACKET[0] {
p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST))
p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST)
p.reader.ReadByte()
isNull = false
} else if p.safePeekContains(JSON_NULL) {
@@ -1039,12 +1065,12 @@ func (p *TSimpleJSONProtocol) ParseElemListBegin() (elemType TType, size int, e
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
bElemType, err := p.ReadByte()
bElemType, _, err := p.ParseI64()
elemType = TType(bElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
nSize, _, err2 := p.ParseI64()
size = int(nSize)
return elemType, size, err2
}
@@ -1053,7 +1079,7 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
if isNull, err := p.readIfNull(); isNull || err != nil {
return err
}
cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
cxt, _ := p.parseContextStack.peek()
if cxt != _CONTEXT_IN_LIST {
e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -1071,8 +1097,10 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
break
}
}
p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL {
p.parseContextStack.pop()
if cxt, ok := p.parseContextStack.peek(); !ok {
return errEmptyJSONContextStack
} else if cxt == _CONTEXT_IN_TOPLEVEL {
return nil
}
return p.ParsePostValue()
@@ -1325,8 +1353,8 @@ func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool {
// Reset the context stack to its initial state.
func (p *TSimpleJSONProtocol) resetContextStack() {
p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)}
p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)}
p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
}
func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
@@ -1336,3 +1364,10 @@ func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
}
return n, err
}
// SetTConfiguration implements TConfigurationSetter for propagation.
func (p *TSimpleJSONProtocol) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(p.trans, conf)
}
var _ TConfigurationSetter = (*TSimpleJSONProtocol)(nil)
@@ -20,12 +20,34 @@
package thrift
import (
"log"
"runtime/debug"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
)
// ErrAbandonRequest is a special error server handler implementations can
// return to indicate that the request has been abandoned.
//
// TSimpleServer will check for this error, and close the client connection
// instead of writing the response/error back to the client.
//
// It shall only be used when the server handler implementation know that the
// client already abandoned the request (by checking that the passed in context
// is already canceled, for example).
var ErrAbandonRequest = errors.New("request abandoned")
// ServerConnectivityCheckInterval defines the ticker interval used by
// connectivity check in thrift compiled TProcessorFunc implementations.
//
// It's defined as a variable instead of constant, so that thrift server
// implementations can change its value to control the behavior.
//
// If it's changed to <=0, the feature will be disabled.
var ServerConnectivityCheckInterval = time.Millisecond * 5
/*
* This is not a typical TSimpleServer as it is not blocked after accept a socket.
* It is more like a TThreadedServer that can handle different connections in different goroutines.
@@ -45,6 +67,8 @@ type TSimpleServer struct {
// Headers to auto forward in THeaderProtocol
forwardHeaders []string
logger Logger
}
func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
@@ -148,6 +172,14 @@ func (p *TSimpleServer) SetForwardHeaders(headers []string) {
p.forwardHeaders = keys
}
// SetLogger sets the logger used by this TSimpleServer.
//
// If no logger was set before Serve is called, a default logger using standard
// log library will be used.
func (p *TSimpleServer) SetLogger(logger Logger) {
p.logger = logger
}
func (p *TSimpleServer) innerAccept() (int32, error) {
client, err := p.serverTransport.Accept()
p.mu.Lock()
@@ -164,7 +196,7 @@ func (p *TSimpleServer) innerAccept() (int32, error) {
go func() {
defer p.wg.Done()
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
p.logger(fmt.Sprintf("error processing request: %v", err))
}
}()
}
@@ -184,6 +216,8 @@ func (p *TSimpleServer) AcceptLoop() error {
}
func (p *TSimpleServer) Serve() error {
p.logger = fallbackLogger(p.logger)
err := p.Listen()
if err != nil {
return err
@@ -204,7 +238,26 @@ func (p *TSimpleServer) Stop() error {
return nil
}
func (p *TSimpleServer) processRequests(client TTransport) error {
// If err is actually EOF, return nil, otherwise return err as-is.
func treatEOFErrorsAsNil(err error) error {
if err == nil {
return nil
}
if errors.Is(err, io.EOF) {
return nil
}
var te TTransportException
if errors.As(err, &te) && te.TypeId() == END_OF_FILE {
return nil
}
return err
}
func (p *TSimpleServer) processRequests(client TTransport) (err error) {
defer func() {
err = treatEOFErrorsAsNil(err)
}()
processor := p.processorFactory.GetProcessor(client)
inputTransport, err := p.inputTransportFactory.GetTransport(client)
if err != nil {
@@ -229,12 +282,6 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport)
}
defer func() {
if e := recover(); e != nil {
log.Printf("panic in processor: %s: %s", e, debug.Stack())
}
}()
if inputTransport != nil {
defer inputTransport.Close()
}
@@ -246,7 +293,12 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
return nil
}
ctx := defaultCtx
ctx := SetResponseHelper(
defaultCtx,
TResponseHelper{
THeaderResponseHelper: NewTHeaderResponseHelper(outputProtocol),
},
)
if headerProtocol != nil {
// We need to call ReadFrame here, otherwise we won't
// get any headers on the AddReadTHeaderToContext call.
@@ -254,20 +306,22 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
// ReadFrame is safe to be called multiple times so it
// won't break when it's called again later when we
// actually start to read the message.
if err := headerProtocol.ReadFrame(); err != nil {
if err := headerProtocol.ReadFrame(ctx); err != nil {
return err
}
ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders())
ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders())
ctx = SetWriteHeaderList(ctx, p.forwardHeaders)
}
ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
return nil
} else if err != nil {
if errors.Is(err, ErrAbandonRequest) {
return client.Close()
}
if errors.As(err, new(TTransportException)) && err != nil {
return err
}
if err, ok := err.(TApplicationException); ok && err.TypeId() == UNKNOWN_METHOD {
var tae TApplicationException
if errors.As(err, &tae) && tae.TypeId() == UNKNOWN_METHOD {
continue
}
if !ok {
@@ -26,50 +26,116 @@ import (
)
type TSocket struct {
conn net.Conn
addr net.Addr
timeout time.Duration
conn *socketConn
addr net.Addr
cfg *TConfiguration
connectTimeout time.Duration
socketTimeout time.Duration
}
// NewTSocket creates a net.Conn-backed TTransport, given a host and port
// Deprecated: Use NewTSocketConf instead.
func NewTSocket(hostPort string) (*TSocket, error) {
return NewTSocketConf(hostPort, &TConfiguration{
noPropagation: true,
})
}
// NewTSocketConf creates a net.Conn-backed TTransport, given a host and port.
//
// Example:
// trans, err := thrift.NewTSocket("localhost:9090")
func NewTSocket(hostPort string) (*TSocket, error) {
return NewTSocketTimeout(hostPort, 0)
}
// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a timeout as a time.Duration
func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, error) {
//conn, err := net.DialTimeout(network, address, timeout)
//
// trans, err := thrift.NewTSocketConf("localhost:9090", &TConfiguration{
// ConnectTimeout: time.Second, // Use 0 for no timeout
// SocketTimeout: time.Second, // Use 0 for no timeout
// })
func NewTSocketConf(hostPort string, conf *TConfiguration) (*TSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", hostPort)
if err != nil {
return nil, err
}
return NewTSocketFromAddrTimeout(addr, timeout), nil
return NewTSocketFromAddrConf(addr, conf), nil
}
// Creates a TSocket from a net.Addr
func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket {
return &TSocket{addr: addr, timeout: timeout}
// Deprecated: Use NewTSocketConf instead.
func NewTSocketTimeout(hostPort string, connTimeout time.Duration, soTimeout time.Duration) (*TSocket, error) {
return NewTSocketConf(hostPort, &TConfiguration{
ConnectTimeout: connTimeout,
SocketTimeout: soTimeout,
noPropagation: true,
})
}
// Creates a TSocket from an existing net.Conn
func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket {
return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout}
// NewTSocketFromAddrConf creates a TSocket from a net.Addr
func NewTSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSocket {
return &TSocket{
addr: addr,
cfg: conf,
}
}
// Deprecated: Use NewTSocketFromAddrConf instead.
func NewTSocketFromAddrTimeout(addr net.Addr, connTimeout time.Duration, soTimeout time.Duration) *TSocket {
return NewTSocketFromAddrConf(addr, &TConfiguration{
ConnectTimeout: connTimeout,
SocketTimeout: soTimeout,
noPropagation: true,
})
}
// NewTSocketFromConnConf creates a TSocket from an existing net.Conn.
func NewTSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSocket {
return &TSocket{
conn: wrapSocketConn(conn),
addr: conn.RemoteAddr(),
cfg: conf,
}
}
// Deprecated: Use NewTSocketFromConnConf instead.
func NewTSocketFromConnTimeout(conn net.Conn, socketTimeout time.Duration) *TSocket {
return NewTSocketFromConnConf(conn, &TConfiguration{
SocketTimeout: socketTimeout,
noPropagation: true,
})
}
// SetTConfiguration implements TConfigurationSetter.
//
// It can be used to set connect and socket timeouts.
func (p *TSocket) SetTConfiguration(conf *TConfiguration) {
p.cfg = conf
}
// Sets the connect timeout
func (p *TSocket) SetConnTimeout(timeout time.Duration) error {
if p.cfg == nil {
p.cfg = &TConfiguration{
noPropagation: true,
}
}
p.cfg.ConnectTimeout = timeout
return nil
}
// Sets the socket timeout
func (p *TSocket) SetTimeout(timeout time.Duration) error {
p.timeout = timeout
func (p *TSocket) SetSocketTimeout(timeout time.Duration) error {
if p.cfg == nil {
p.cfg = &TConfiguration{
noPropagation: true,
}
}
p.cfg.SocketTimeout = timeout
return nil
}
func (p *TSocket) pushDeadline(read, write bool) {
var t time.Time
if p.timeout > 0 {
t = time.Now().Add(time.Duration(p.timeout))
if timeout := p.cfg.GetSocketTimeout(); timeout > 0 {
t = time.Now().Add(time.Duration(timeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -82,7 +148,7 @@ func (p *TSocket) pushDeadline(read, write bool) {
// Connects the socket, creating a new socket object if necessary.
func (p *TSocket) Open() error {
if p.IsOpen() {
if p.conn.isValid() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
@@ -95,7 +161,11 @@ func (p *TSocket) Open() error {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
var err error
if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.timeout); err != nil {
if p.conn, err = createSocketConnFromReturn(net.DialTimeout(
p.addr.Network(),
p.addr.String(),
p.cfg.GetConnectTimeout(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
return nil
@@ -108,10 +178,7 @@ func (p *TSocket) Conn() net.Conn {
// Returns true if the connection is open
func (p *TSocket) IsOpen() bool {
if p.conn == nil {
return false
}
return true
return p.conn.IsOpen()
}
// Closes the socket.
@@ -133,16 +200,19 @@ func (p *TSocket) Addr() net.Addr {
}
func (p *TSocket) Read(buf []byte) (int, error) {
if !p.IsOpen() {
if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
// NOTE: Calling any of p.IsOpen, p.conn.read0, or p.conn.IsOpen between
// p.pushDeadline and p.conn.Read could cause the deadline set inside
// p.pushDeadline being reset, thus need to be avoided.
n, err := p.conn.Read(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TSocket) Write(buf []byte) (int, error) {
if !p.IsOpen() {
if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
@@ -154,7 +224,7 @@ func (p *TSocket) Flush(ctx context.Context) error {
}
func (p *TSocket) Interrupt() error {
if !p.IsOpen() {
if !p.conn.isValid() {
return nil
}
return p.conn.Close()
@@ -164,3 +234,5 @@ func (p *TSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
var _ TConfigurationSetter = (*TSocket)(nil)
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"net"
)
// socketConn is a wrapped net.Conn that tries to do connectivity check.
type socketConn struct {
net.Conn
buffer [1]byte
}
var _ net.Conn = (*socketConn)(nil)
// createSocketConnFromReturn is a language sugar to help create socketConn from
// return values of functions like net.Dial, tls.Dial, net.Listener.Accept, etc.
func createSocketConnFromReturn(conn net.Conn, err error) (*socketConn, error) {
if err != nil {
return nil, err
}
return &socketConn{
Conn: conn,
}, nil
}
// wrapSocketConn wraps an existing net.Conn into *socketConn.
func wrapSocketConn(conn net.Conn) *socketConn {
// In case conn is already wrapped,
// return it as-is and avoid double wrapping.
if sc, ok := conn.(*socketConn); ok {
return sc
}
return &socketConn{
Conn: conn,
}
}
// isValid checks whether there's a valid connection.
//
// It's nil safe, and returns false if sc itself is nil, or if the underlying
// connection is nil.
//
// It's the same as the previous implementation of TSocket.IsOpen and
// TSSLSocket.IsOpen before we added connectivity check.
func (sc *socketConn) isValid() bool {
return sc != nil && sc.Conn != nil
}
// IsOpen checks whether the connection is open.
//
// It's nil safe, and returns false if sc itself is nil, or if the underlying
// connection is nil.
//
// Otherwise, it tries to do a connectivity check and returns the result.
//
// It also has the side effect of resetting the previously set read deadline on
// the socket. As a result, it shouldn't be called between setting read deadline
// and doing actual read.
func (sc *socketConn) IsOpen() bool {
if !sc.isValid() {
return false
}
return sc.checkConn() == nil
}
// Read implements io.Reader.
//
// On Windows, it behaves the same as the underlying net.Conn.Read.
//
// On non-Windows, it treats len(p) == 0 as a connectivity check instead of
// readability check, which means instead of blocking until there's something to
// read (readability check), or always return (0, nil) (the default behavior of
// go's stdlib implementation on non-Windows), it never blocks, and will return
// an error if the connection is lost.
func (sc *socketConn) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, sc.read0()
}
return sc.Conn.Read(p)
}
@@ -0,0 +1,83 @@
// +build !windows
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
"io"
"syscall"
"time"
)
// We rely on this variable to be the zero time,
// but define it as global variable to avoid repetitive allocations.
// Please DO NOT mutate this variable in any way.
var zeroTime time.Time
func (sc *socketConn) read0() error {
return sc.checkConn()
}
func (sc *socketConn) checkConn() error {
syscallConn, ok := sc.Conn.(syscall.Conn)
if !ok {
// No way to check, return nil
return nil
}
// The reading about to be done here is non-blocking so we don't really
// need a read deadline. We just need to clear the previously set read
// deadline, if any.
sc.Conn.SetReadDeadline(zeroTime)
rc, err := syscallConn.SyscallConn()
if err != nil {
return err
}
var n int
if readErr := rc.Read(func(fd uintptr) bool {
n, _, err = syscall.Recvfrom(int(fd), sc.buffer[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
return true
}); readErr != nil {
return readErr
}
if n > 0 {
// We got something, which means we are good
return nil
}
if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK) {
// This means the connection is still open but we don't have
// anything to read right now.
return nil
}
if err != nil {
return err
}
// At this point, it means the other side already closed the connection.
return io.EOF
}
@@ -0,0 +1,34 @@
// +build windows
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
func (sc *socketConn) read0() error {
// On windows, we fallback to the default behavior of reading 0 bytes.
var p []byte
_, err := sc.Conn.Read(p)
return err
}
func (sc *socketConn) checkConn() error {
// On windows, we always return nil for this check.
return nil
}
@@ -27,54 +27,122 @@ import (
)
type TSSLSocket struct {
conn net.Conn
conn *socketConn
// hostPort contains host:port (e.g. "asdf.com:12345"). The field is
// only valid if addr is nil.
hostPort string
// addr is nil when hostPort is not "", and is only used when the
// TSSLSocket is constructed from a net.Addr.
addr net.Addr
timeout time.Duration
cfg *tls.Config
addr net.Addr
cfg *TConfiguration
}
// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration
// NewTSSLSocketConf creates a net.Conn-backed TTransport, given a host and port.
//
// Example:
// trans, err := thrift.NewTSSLSocket("localhost:9090", nil)
func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
return NewTSSLSocketTimeout(hostPort, cfg, 0)
}
// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a tls Configuration and a timeout as a time.Duration
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
if cfg.MinVersion == 0 {
//
// trans, err := thrift.NewTSSLSocketConf("localhost:9090", nil, &TConfiguration{
// ConnectTimeout: time.Second, // Use 0 for no timeout
// SocketTimeout: time.Second, // Use 0 for no timeout
// })
func NewTSSLSocketConf(hostPort string, conf *TConfiguration) (*TSSLSocket, error) {
if cfg := conf.GetTLSConfig(); cfg != nil && cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil
return &TSSLSocket{
hostPort: hostPort,
cfg: conf,
}, nil
}
// Creates a TSSLSocket from a net.Addr
func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg}
// Deprecated: Use NewTSSLSocketConf instead.
func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
return NewTSSLSocketConf(hostPort, &TConfiguration{
TLSConfig: cfg,
noPropagation: true,
})
}
// Creates a TSSLSocket from an existing net.Conn
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
// Deprecated: Use NewTSSLSocketConf instead.
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) (*TSSLSocket, error) {
return NewTSSLSocketConf(hostPort, &TConfiguration{
ConnectTimeout: connectTimeout,
SocketTimeout: socketTimeout,
TLSConfig: cfg,
noPropagation: true,
})
}
// NewTSSLSocketFromAddrConf creates a TSSLSocket from a net.Addr.
func NewTSSLSocketFromAddrConf(addr net.Addr, conf *TConfiguration) *TSSLSocket {
return &TSSLSocket{
addr: addr,
cfg: conf,
}
}
// Deprecated: Use NewTSSLSocketFromAddrConf instead.
func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, connectTimeout, socketTimeout time.Duration) *TSSLSocket {
return NewTSSLSocketFromAddrConf(addr, &TConfiguration{
ConnectTimeout: connectTimeout,
SocketTimeout: socketTimeout,
TLSConfig: cfg,
noPropagation: true,
})
}
// NewTSSLSocketFromConnConf creates a TSSLSocket from an existing net.Conn.
func NewTSSLSocketFromConnConf(conn net.Conn, conf *TConfiguration) *TSSLSocket {
return &TSSLSocket{
conn: wrapSocketConn(conn),
addr: conn.RemoteAddr(),
cfg: conf,
}
}
// Deprecated: Use NewTSSLSocketFromConnConf instead.
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, socketTimeout time.Duration) *TSSLSocket {
return NewTSSLSocketFromConnConf(conn, &TConfiguration{
SocketTimeout: socketTimeout,
TLSConfig: cfg,
noPropagation: true,
})
}
// SetTConfiguration implements TConfigurationSetter.
//
// It can be used to change connect and socket timeouts.
func (p *TSSLSocket) SetTConfiguration(conf *TConfiguration) {
p.cfg = conf
}
// Sets the connect timeout
func (p *TSSLSocket) SetConnTimeout(timeout time.Duration) error {
if p.cfg == nil {
p.cfg = &TConfiguration{}
}
p.cfg.ConnectTimeout = timeout
return nil
}
// Sets the socket timeout
func (p *TSSLSocket) SetTimeout(timeout time.Duration) error {
p.timeout = timeout
func (p *TSSLSocket) SetSocketTimeout(timeout time.Duration) error {
if p.cfg == nil {
p.cfg = &TConfiguration{}
}
p.cfg.SocketTimeout = timeout
return nil
}
func (p *TSSLSocket) pushDeadline(read, write bool) {
var t time.Time
if p.timeout > 0 {
t = time.Now().Add(time.Duration(p.timeout))
if timeout := p.cfg.GetSocketTimeout(); timeout > 0 {
t = time.Now().Add(time.Duration(timeout))
}
if read && write {
p.conn.SetDeadline(t)
@@ -91,12 +159,18 @@ func (p *TSSLSocket) Open() error {
// If we have a hostname, we need to pass the hostname to tls.Dial for
// certificate hostname checks.
if p.hostPort != "" {
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
&net.Dialer{
Timeout: p.cfg.GetConnectTimeout(),
},
"tcp",
p.hostPort,
p.cfg.GetTLSConfig(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
} else {
if p.IsOpen() {
if p.conn.isValid() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
@@ -108,8 +182,14 @@ func (p *TSSLSocket) Open() error {
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
if p.conn, err = createSocketConnFromReturn(tls.DialWithDialer(
&net.Dialer{
Timeout: p.cfg.GetConnectTimeout(),
},
p.addr.Network(),
p.addr.String(),
p.cfg.GetTLSConfig(),
)); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
}
@@ -123,10 +203,7 @@ func (p *TSSLSocket) Conn() net.Conn {
// Returns true if the connection is open
func (p *TSSLSocket) IsOpen() bool {
if p.conn == nil {
return false
}
return true
return p.conn.IsOpen()
}
// Closes the socket.
@@ -143,16 +220,19 @@ func (p *TSSLSocket) Close() error {
}
func (p *TSSLSocket) Read(buf []byte) (int, error) {
if !p.IsOpen() {
if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
// NOTE: Calling any of p.IsOpen, p.conn.read0, or p.conn.IsOpen between
// p.pushDeadline and p.conn.Read could cause the deadline set inside
// p.pushDeadline being reset, thus need to be avoided.
n, err := p.conn.Read(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TSSLSocket) Write(buf []byte) (int, error) {
if !p.IsOpen() {
if !p.conn.isValid() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
@@ -164,7 +244,7 @@ func (p *TSSLSocket) Flush(ctx context.Context) error {
}
func (p *TSSLSocket) Interrupt() error {
if !p.IsOpen() {
if !p.conn.isValid() {
return nil
}
return p.conn.Close()
@@ -172,5 +252,7 @@ func (p *TSSLSocket) Interrupt() error {
func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the truth is, we just don't know unless framed is used
}
var _ TConfigurationSetter = (*TSSLSocket)(nil)
@@ -46,6 +46,13 @@ const (
type tTransportException struct {
typeId int
err error
msg string
}
var _ TTransportException = (*tTransportException)(nil)
func (tTransportException) TExceptionType() TExceptionType {
return TExceptionTypeTransport
}
func (p *tTransportException) TypeId() int {
@@ -53,15 +60,27 @@ func (p *tTransportException) TypeId() int {
}
func (p *tTransportException) Error() string {
return p.err.Error()
return p.msg
}
func (p *tTransportException) Err() error {
return p.err
}
func (p *tTransportException) Unwrap() error {
return p.err
}
func (p *tTransportException) Timeout() bool {
return p.typeId == TIMED_OUT
}
func NewTTransportException(t int, e string) TTransportException {
return &tTransportException{typeId: t, err: errors.New(e)}
return &tTransportException{
typeId: t,
err: errors.New(e),
msg: e,
}
}
func NewTTransportExceptionFromError(e error) TTransportException {
@@ -73,18 +92,40 @@ func NewTTransportExceptionFromError(e error) TTransportException {
return t
}
switch v := e.(type) {
case TTransportException:
return v
case timeoutable:
if v.Timeout() {
return &tTransportException{typeId: TIMED_OUT, err: e}
}
te := &tTransportException{
typeId: UNKNOWN_TRANSPORT_EXCEPTION,
err: e,
msg: e.Error(),
}
if e == io.EOF {
return &tTransportException{typeId: END_OF_FILE, err: e}
if isTimeoutError(e) {
te.typeId = TIMED_OUT
return te
}
return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e}
if errors.Is(e, io.EOF) {
te.typeId = END_OF_FILE
return te
}
return te
}
func prependTTransportException(prepend string, e TTransportException) TTransportException {
return &tTransportException{
typeId: e.TypeId(),
err: e,
msg: prepend + e.Error(),
}
}
// isTimeoutError returns true when err is an error caused by timeout.
//
// Note that this also includes TTransportException wrapped timeout errors.
func isTimeoutError(err error) bool {
var t timeoutable
if errors.As(err, &t) {
return t.Timeout()
}
return false
}
@@ -23,7 +23,6 @@ import (
"compress/zlib"
"context"
"io"
"log"
)
// TZlibTransportFactory is a factory for TZlibTransport instances
@@ -67,7 +66,6 @@ func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *
func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) {
w, err := zlib.NewWriterLevel(trans, level)
if err != nil {
log.Println(err)
return nil, err
}
@@ -130,3 +128,10 @@ func (z *TZlibTransport) RemainingBytes() uint64 {
func (z *TZlibTransport) Write(p []byte) (int, error) {
return z.writer.Write(p)
}
// SetTConfiguration implements TConfigurationSetter for propagation.
func (z *TZlibTransport) SetTConfiguration(conf *TConfiguration) {
PropagateTConfiguration(z.transport, conf)
}
var _ TConfigurationSetter = (*TZlibTransport)(nil)
+2 -1
View File
@@ -16,6 +16,7 @@ package jaeger // import "go.opentelemetry.io/otel/exporters/trace/jaeger"
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@@ -209,7 +210,7 @@ func (c *collectorUploader) upload(batch *gen.Batch) error {
func serialize(obj thrift.TStruct) (*bytes.Buffer, error) {
buf := thrift.NewTMemoryBuffer()
if err := obj.Write(thrift.NewTBinaryProtocolTransport(buf)); err != nil {
if err := obj.Write(context.Background(), thrift.NewTBinaryProtocolConf(buf, &thrift.TConfiguration{})); err != nil {
return nil, err
}
return buf.Buffer, nil