mirror of
https://github.com/go-kratos/kratos.git
synced 2025-01-10 00:29:01 +02:00
blademaster initial (#6)
This commit is contained in:
parent
1efe0a084e
commit
96d32e866a
85
pkg/net/http/blademaster/binding/binding.go
Normal file
85
pkg/net/http/blademaster/binding/binding.go
Normal file
@ -0,0 +1,85 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
// MIME
|
||||
const (
|
||||
MIMEJSON = "application/json"
|
||||
MIMEHTML = "text/html"
|
||||
MIMEXML = "application/xml"
|
||||
MIMEXML2 = "text/xml"
|
||||
MIMEPlain = "text/plain"
|
||||
MIMEPOSTForm = "application/x-www-form-urlencoded"
|
||||
MIMEMultipartPOSTForm = "multipart/form-data"
|
||||
)
|
||||
|
||||
// Binding http binding request interface.
|
||||
type Binding interface {
|
||||
Name() string
|
||||
Bind(*http.Request, interface{}) error
|
||||
}
|
||||
|
||||
// StructValidator http validator interface.
|
||||
type StructValidator interface {
|
||||
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
|
||||
// If the received type is not a struct, any validation should be skipped and nil must be returned.
|
||||
// If the received type is a struct or pointer to a struct, the validation should be performed.
|
||||
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
|
||||
// Otherwise nil must be returned.
|
||||
ValidateStruct(interface{}) error
|
||||
|
||||
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||
RegisterValidation(string, validator.Func) error
|
||||
}
|
||||
|
||||
// Validator default validator.
|
||||
var Validator StructValidator = &defaultValidator{}
|
||||
|
||||
// Binding
|
||||
var (
|
||||
JSON = jsonBinding{}
|
||||
XML = xmlBinding{}
|
||||
Form = formBinding{}
|
||||
Query = queryBinding{}
|
||||
FormPost = formPostBinding{}
|
||||
FormMultipart = formMultipartBinding{}
|
||||
)
|
||||
|
||||
// Default get by binding type by method and contexttype.
|
||||
func Default(method, contentType string) Binding {
|
||||
if method == "GET" {
|
||||
return Form
|
||||
}
|
||||
|
||||
contentType = stripContentTypeParam(contentType)
|
||||
switch contentType {
|
||||
case MIMEJSON:
|
||||
return JSON
|
||||
case MIMEXML, MIMEXML2:
|
||||
return XML
|
||||
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
|
||||
return Form
|
||||
}
|
||||
}
|
||||
|
||||
func validate(obj interface{}) error {
|
||||
if Validator == nil {
|
||||
return nil
|
||||
}
|
||||
return Validator.ValidateStruct(obj)
|
||||
}
|
||||
|
||||
func stripContentTypeParam(contentType string) string {
|
||||
i := strings.Index(contentType, ";")
|
||||
if i != -1 {
|
||||
contentType = contentType[:i]
|
||||
}
|
||||
return contentType
|
||||
}
|
342
pkg/net/http/blademaster/binding/binding_test.go
Normal file
342
pkg/net/http/blademaster/binding/binding_test.go
Normal file
@ -0,0 +1,342 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type FooStruct struct {
|
||||
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"`
|
||||
}
|
||||
|
||||
type FooBarStruct struct {
|
||||
FooStruct
|
||||
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"`
|
||||
Slice []string `form:"slice" validate:"max=10"`
|
||||
}
|
||||
|
||||
type ComplexDefaultStruct struct {
|
||||
Int int `form:"int" default:"999"`
|
||||
String string `form:"string" default:"default-string"`
|
||||
Bool bool `form:"bool" default:"false"`
|
||||
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"`
|
||||
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"`
|
||||
}
|
||||
|
||||
type Int8SliceStruct struct {
|
||||
State []int8 `form:"state,split"`
|
||||
}
|
||||
|
||||
type Int64SliceStruct struct {
|
||||
State []int64 `form:"state,split"`
|
||||
}
|
||||
|
||||
type StringSliceStruct struct {
|
||||
State []string `form:"state,split"`
|
||||
}
|
||||
|
||||
func TestBindingDefault(t *testing.T) {
|
||||
assert.Equal(t, Default("GET", ""), Form)
|
||||
assert.Equal(t, Default("GET", MIMEJSON), Form)
|
||||
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form)
|
||||
|
||||
assert.Equal(t, Default("POST", MIMEJSON), JSON)
|
||||
assert.Equal(t, Default("PUT", MIMEJSON), JSON)
|
||||
|
||||
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON)
|
||||
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON)
|
||||
|
||||
assert.Equal(t, Default("POST", MIMEXML), XML)
|
||||
assert.Equal(t, Default("PUT", MIMEXML2), XML)
|
||||
|
||||
assert.Equal(t, Default("POST", MIMEPOSTForm), Form)
|
||||
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form)
|
||||
|
||||
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form)
|
||||
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form)
|
||||
|
||||
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form)
|
||||
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form)
|
||||
|
||||
}
|
||||
|
||||
func TestStripContentType(t *testing.T) {
|
||||
c1 := "application/vnd.mozilla.xul+xml"
|
||||
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8"
|
||||
assert.Equal(t, stripContentTypeParam(c1), c1)
|
||||
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml")
|
||||
}
|
||||
|
||||
func TestBindInt8Form(t *testing.T) {
|
||||
params := "state=1,2,3"
|
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q := new(Int8SliceStruct)
|
||||
Form.Bind(req, q)
|
||||
assert.EqualValues(t, []int8{1, 2, 3}, q.State)
|
||||
|
||||
params = "state=1,2,3,256"
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(Int8SliceStruct)
|
||||
assert.Error(t, Form.Bind(req, q))
|
||||
|
||||
params = "state="
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(Int8SliceStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Len(t, q.State, 0)
|
||||
|
||||
params = "state=1,,2"
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(Int8SliceStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.EqualValues(t, []int8{1, 2}, q.State)
|
||||
}
|
||||
|
||||
func TestBindInt64Form(t *testing.T) {
|
||||
params := "state=1,2,3"
|
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q := new(Int64SliceStruct)
|
||||
Form.Bind(req, q)
|
||||
assert.EqualValues(t, []int64{1, 2, 3}, q.State)
|
||||
|
||||
params = "state="
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(Int64SliceStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Len(t, q.State, 0)
|
||||
}
|
||||
|
||||
func TestBindStringForm(t *testing.T) {
|
||||
params := "state=1,2,3"
|
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q := new(StringSliceStruct)
|
||||
Form.Bind(req, q)
|
||||
assert.EqualValues(t, []string{"1", "2", "3"}, q.State)
|
||||
|
||||
params = "state="
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(StringSliceStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Len(t, q.State, 0)
|
||||
|
||||
params = "state=p,,p"
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(StringSliceStruct)
|
||||
Form.Bind(req, q)
|
||||
assert.EqualValues(t, []string{"p", "p"}, q.State)
|
||||
}
|
||||
|
||||
func TestBindingJSON(t *testing.T) {
|
||||
testBodyBinding(t,
|
||||
JSON, "json",
|
||||
"/", "/",
|
||||
`{"foo": "bar"}`, `{"bar": "foo"}`)
|
||||
}
|
||||
|
||||
func TestBindingForm(t *testing.T) {
|
||||
testFormBinding(t, "POST",
|
||||
"/", "/",
|
||||
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo")
|
||||
}
|
||||
|
||||
func TestBindingForm2(t *testing.T) {
|
||||
testFormBinding(t, "GET",
|
||||
"/?foo=bar&bar=foo", "/?bar2=foo",
|
||||
"", "")
|
||||
}
|
||||
|
||||
func TestBindingQuery(t *testing.T) {
|
||||
testQueryBinding(t, "POST",
|
||||
"/?foo=bar&bar=foo", "/",
|
||||
"foo=unused", "bar2=foo")
|
||||
}
|
||||
|
||||
func TestBindingQuery2(t *testing.T) {
|
||||
testQueryBinding(t, "GET",
|
||||
"/?foo=bar&bar=foo", "/?bar2=foo",
|
||||
"foo=unused", "")
|
||||
}
|
||||
|
||||
func TestBindingXML(t *testing.T) {
|
||||
testBodyBinding(t,
|
||||
XML, "xml",
|
||||
"/", "/",
|
||||
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>")
|
||||
}
|
||||
|
||||
func createFormPostRequest() *http.Request {
|
||||
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo"))
|
||||
req.Header.Set("Content-Type", MIMEPOSTForm)
|
||||
return req
|
||||
}
|
||||
|
||||
func createFormMultipartRequest() *http.Request {
|
||||
boundary := "--testboundary"
|
||||
body := new(bytes.Buffer)
|
||||
mw := multipart.NewWriter(body)
|
||||
defer mw.Close()
|
||||
|
||||
mw.SetBoundary(boundary)
|
||||
mw.WriteField("foo", "bar")
|
||||
mw.WriteField("bar", "foo")
|
||||
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body)
|
||||
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
|
||||
return req
|
||||
}
|
||||
|
||||
func TestBindingFormPost(t *testing.T) {
|
||||
req := createFormPostRequest()
|
||||
var obj FooBarStruct
|
||||
FormPost.Bind(req, &obj)
|
||||
|
||||
assert.Equal(t, obj.Foo, "bar")
|
||||
assert.Equal(t, obj.Bar, "foo")
|
||||
}
|
||||
|
||||
func TestBindingFormMultipart(t *testing.T) {
|
||||
req := createFormMultipartRequest()
|
||||
var obj FooBarStruct
|
||||
FormMultipart.Bind(req, &obj)
|
||||
|
||||
assert.Equal(t, obj.Foo, "bar")
|
||||
assert.Equal(t, obj.Bar, "foo")
|
||||
}
|
||||
|
||||
func TestValidationFails(t *testing.T) {
|
||||
var obj FooStruct
|
||||
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
|
||||
err := JSON.Bind(req, &obj)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidationDisabled(t *testing.T) {
|
||||
backup := Validator
|
||||
Validator = nil
|
||||
defer func() { Validator = backup }()
|
||||
|
||||
var obj FooStruct
|
||||
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
|
||||
err := JSON.Bind(req, &obj)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestExistsSucceeds(t *testing.T) {
|
||||
type HogeStruct struct {
|
||||
Hoge *int `json:"hoge" binding:"exists"`
|
||||
}
|
||||
|
||||
var obj HogeStruct
|
||||
req := requestWithBody("POST", "/", `{"hoge": 0}`)
|
||||
err := JSON.Bind(req, &obj)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFormDefaultValue(t *testing.T) {
|
||||
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8"
|
||||
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q := new(ComplexDefaultStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Equal(t, 333, q.Int)
|
||||
assert.Equal(t, "hello", q.String)
|
||||
assert.Equal(t, true, q.Bool)
|
||||
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice)
|
||||
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice)
|
||||
|
||||
params = "string=hello&bool=false"
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(ComplexDefaultStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Equal(t, 999, q.Int)
|
||||
assert.Equal(t, "hello", q.String)
|
||||
assert.Equal(t, false, q.Bool)
|
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
|
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
|
||||
|
||||
params = "strings=hello"
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(ComplexDefaultStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Equal(t, 999, q.Int)
|
||||
assert.Equal(t, "default-string", q.String)
|
||||
assert.Equal(t, false, q.Bool)
|
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
|
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
|
||||
|
||||
params = "int=&string=&bool=true&int64_slice=&int8_slice="
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
q = new(ComplexDefaultStruct)
|
||||
assert.NoError(t, Form.Bind(req, q))
|
||||
assert.Equal(t, 999, q.Int)
|
||||
assert.Equal(t, "default-string", q.String)
|
||||
assert.Equal(t, true, q.Bool)
|
||||
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
|
||||
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
|
||||
}
|
||||
|
||||
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) {
|
||||
b := Form
|
||||
assert.Equal(t, b.Name(), "form")
|
||||
|
||||
obj := FooBarStruct{}
|
||||
req := requestWithBody(method, path, body)
|
||||
if method == "POST" {
|
||||
req.Header.Add("Content-Type", MIMEPOSTForm)
|
||||
}
|
||||
err := b.Bind(req, &obj)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, obj.Foo, "bar")
|
||||
assert.Equal(t, obj.Bar, "foo")
|
||||
|
||||
obj = FooBarStruct{}
|
||||
req = requestWithBody(method, badPath, badBody)
|
||||
err = JSON.Bind(req, &obj)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) {
|
||||
b := Query
|
||||
assert.Equal(t, b.Name(), "query")
|
||||
|
||||
obj := FooBarStruct{}
|
||||
req := requestWithBody(method, path, body)
|
||||
if method == "POST" {
|
||||
req.Header.Add("Content-Type", MIMEPOSTForm)
|
||||
}
|
||||
err := b.Bind(req, &obj)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, obj.Foo, "bar")
|
||||
assert.Equal(t, obj.Bar, "foo")
|
||||
}
|
||||
|
||||
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
|
||||
assert.Equal(t, b.Name(), name)
|
||||
|
||||
obj := FooStruct{}
|
||||
req := requestWithBody("POST", path, body)
|
||||
err := b.Bind(req, &obj)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, obj.Foo, "bar")
|
||||
|
||||
obj = FooStruct{}
|
||||
req = requestWithBody("POST", badPath, badBody)
|
||||
err = JSON.Bind(req, &obj)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func requestWithBody(method, path, body string) (req *http.Request) {
|
||||
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body))
|
||||
return
|
||||
}
|
||||
func BenchmarkBindingForm(b *testing.B) {
|
||||
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w")
|
||||
req.Header.Add("Content-Type", MIMEPOSTForm)
|
||||
f := Form
|
||||
for i := 0; i < b.N; i++ {
|
||||
obj := FooBarStruct{}
|
||||
f.Bind(req, &obj)
|
||||
}
|
||||
}
|
45
pkg/net/http/blademaster/binding/default_validator.go
Normal file
45
pkg/net/http/blademaster/binding/default_validator.go
Normal file
@ -0,0 +1,45 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
type defaultValidator struct {
|
||||
once sync.Once
|
||||
validate *validator.Validate
|
||||
}
|
||||
|
||||
var _ StructValidator = &defaultValidator{}
|
||||
|
||||
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
|
||||
if kindOfData(obj) == reflect.Struct {
|
||||
v.lazyinit()
|
||||
if err := v.validate.Struct(obj); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error {
|
||||
v.lazyinit()
|
||||
return v.validate.RegisterValidation(key, fn)
|
||||
}
|
||||
|
||||
func (v *defaultValidator) lazyinit() {
|
||||
v.once.Do(func() {
|
||||
v.validate = validator.New()
|
||||
})
|
||||
}
|
||||
|
||||
func kindOfData(data interface{}) reflect.Kind {
|
||||
value := reflect.ValueOf(data)
|
||||
valueType := value.Kind()
|
||||
if valueType == reflect.Ptr {
|
||||
valueType = value.Elem().Kind()
|
||||
}
|
||||
return valueType
|
||||
}
|
113
pkg/net/http/blademaster/binding/example/test.pb.go
Normal file
113
pkg/net/http/blademaster/binding/example/test.pb.go
Normal file
@ -0,0 +1,113 @@
|
||||
// Code generated by protoc-gen-go.
|
||||
// source: test.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/*
|
||||
Package example is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
test.proto
|
||||
|
||||
It has these top-level messages:
|
||||
Test
|
||||
*/
|
||||
package example
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import math "math"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = math.Inf
|
||||
|
||||
type FOO int32
|
||||
|
||||
const (
|
||||
FOO_X FOO = 17
|
||||
)
|
||||
|
||||
var FOO_name = map[int32]string{
|
||||
17: "X",
|
||||
}
|
||||
var FOO_value = map[string]int32{
|
||||
"X": 17,
|
||||
}
|
||||
|
||||
func (x FOO) Enum() *FOO {
|
||||
p := new(FOO)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
func (x FOO) String() string {
|
||||
return proto.EnumName(FOO_name, int32(x))
|
||||
}
|
||||
func (x *FOO) UnmarshalJSON(data []byte) error {
|
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = FOO(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test struct {
|
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
|
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
|
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
|
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Test) Reset() { *m = Test{} }
|
||||
func (m *Test) String() string { return proto.CompactTextString(m) }
|
||||
func (*Test) ProtoMessage() {}
|
||||
|
||||
const Default_Test_Type int32 = 77
|
||||
|
||||
func (m *Test) GetLabel() string {
|
||||
if m != nil && m.Label != nil {
|
||||
return *m.Label
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Test) GetType() int32 {
|
||||
if m != nil && m.Type != nil {
|
||||
return *m.Type
|
||||
}
|
||||
return Default_Test_Type
|
||||
}
|
||||
|
||||
func (m *Test) GetReps() []int64 {
|
||||
if m != nil {
|
||||
return m.Reps
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
|
||||
if m != nil {
|
||||
return m.Optionalgroup
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test_OptionalGroup struct {
|
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
|
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
|
||||
func (*Test_OptionalGroup) ProtoMessage() {}
|
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string {
|
||||
if m != nil && m.RequiredField != nil {
|
||||
return *m.RequiredField
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
|
||||
}
|
12
pkg/net/http/blademaster/binding/example/test.proto
Normal file
12
pkg/net/http/blademaster/binding/example/test.proto
Normal file
@ -0,0 +1,12 @@
|
||||
package example;
|
||||
|
||||
enum FOO {X=17;};
|
||||
|
||||
message Test {
|
||||
required string label = 1;
|
||||
optional int32 type = 2[default=77];
|
||||
repeated int64 reps = 3;
|
||||
optional group OptionalGroup = 4{
|
||||
required string RequiredField = 5;
|
||||
}
|
||||
}
|
36
pkg/net/http/blademaster/binding/example_test.go
Normal file
36
pkg/net/http/blademaster/binding/example_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Arg struct {
|
||||
Max int64 `form:"max" validate:"max=10"`
|
||||
Min int64 `form:"min" validate:"min=2"`
|
||||
Range int64 `form:"range" validate:"min=1,max=10"`
|
||||
// use split option to split arg 1,2,3 into slice [1 2 3]
|
||||
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
|
||||
Slice []int64 `form:"slice,split" validate:"min=1"`
|
||||
}
|
||||
|
||||
func ExampleBinding() {
|
||||
req := initHTTP("max=9&min=3&range=3&slice=1,2,3")
|
||||
arg := new(Arg)
|
||||
if err := Form.Bind(req, arg); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice)
|
||||
// Output:
|
||||
// arg.Max 9
|
||||
// arg.Min 3
|
||||
// arg.Range 3
|
||||
// arg.Slice [1 2 3]
|
||||
}
|
||||
|
||||
func initHTTP(params string) (req *http.Request) {
|
||||
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
|
||||
req.ParseForm()
|
||||
return
|
||||
}
|
55
pkg/net/http/blademaster/binding/form.go
Normal file
55
pkg/net/http/blademaster/binding/form.go
Normal file
@ -0,0 +1,55 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const defaultMemory = 32 * 1024 * 1024
|
||||
|
||||
type formBinding struct{}
|
||||
type formPostBinding struct{}
|
||||
type formMultipartBinding struct{}
|
||||
|
||||
func (f formBinding) Name() string {
|
||||
return "form"
|
||||
}
|
||||
|
||||
func (f formBinding) Bind(req *http.Request, obj interface{}) error {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
if err := mapForm(obj, req.Form); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
||||
|
||||
func (f formPostBinding) Name() string {
|
||||
return "form-urlencoded"
|
||||
}
|
||||
|
||||
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
if err := mapForm(obj, req.PostForm); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
||||
|
||||
func (f formMultipartBinding) Name() string {
|
||||
return "multipart/form-data"
|
||||
}
|
||||
|
||||
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
|
||||
if err := req.ParseMultipartForm(defaultMemory); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
if err := mapForm(obj, req.MultipartForm.Value); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
276
pkg/net/http/blademaster/binding/form_mapping.go
Normal file
276
pkg/net/http/blademaster/binding/form_mapping.go
Normal file
@ -0,0 +1,276 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// scache struct reflect type cache.
|
||||
var scache = &cache{
|
||||
data: make(map[reflect.Type]*sinfo),
|
||||
}
|
||||
|
||||
type cache struct {
|
||||
data map[reflect.Type]*sinfo
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *cache) get(obj reflect.Type) (s *sinfo) {
|
||||
var ok bool
|
||||
c.mutex.RLock()
|
||||
if s, ok = c.data[obj]; !ok {
|
||||
c.mutex.RUnlock()
|
||||
s = c.set(obj)
|
||||
return
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *cache) set(obj reflect.Type) (s *sinfo) {
|
||||
s = new(sinfo)
|
||||
tp := obj.Elem()
|
||||
for i := 0; i < tp.NumField(); i++ {
|
||||
fd := new(field)
|
||||
fd.tp = tp.Field(i)
|
||||
tag := fd.tp.Tag.Get("form")
|
||||
fd.name, fd.option = parseTag(tag)
|
||||
if defV := fd.tp.Tag.Get("default"); defV != "" {
|
||||
dv := reflect.New(fd.tp.Type).Elem()
|
||||
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option)
|
||||
fd.hasDefault = true
|
||||
fd.defaultValue = dv
|
||||
}
|
||||
s.field = append(s.field, fd)
|
||||
}
|
||||
c.mutex.Lock()
|
||||
c.data[obj] = s
|
||||
c.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
type sinfo struct {
|
||||
field []*field
|
||||
}
|
||||
|
||||
type field struct {
|
||||
tp reflect.StructField
|
||||
name string
|
||||
option tagOptions
|
||||
|
||||
hasDefault bool // if field had default value
|
||||
defaultValue reflect.Value // field default value
|
||||
}
|
||||
|
||||
func mapForm(ptr interface{}, form map[string][]string) error {
|
||||
sinfo := scache.get(reflect.TypeOf(ptr))
|
||||
val := reflect.ValueOf(ptr).Elem()
|
||||
for i, fd := range sinfo.field {
|
||||
typeField := fd.tp
|
||||
structField := val.Field(i)
|
||||
if !structField.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
structFieldKind := structField.Kind()
|
||||
inputFieldName := fd.name
|
||||
if inputFieldName == "" {
|
||||
inputFieldName = typeField.Name
|
||||
|
||||
// if "form" tag is nil, we inspect if the field is a struct.
|
||||
// this would not make sense for JSON parsing but it does for a form
|
||||
// since data is flatten
|
||||
if structFieldKind == reflect.Struct {
|
||||
err := mapForm(structField.Addr().Interface(), form)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
inputValue, exists := form[inputFieldName]
|
||||
if !exists {
|
||||
// Set the field as default value when the input value is not exist
|
||||
if fd.hasDefault {
|
||||
structField.Set(fd.defaultValue)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Set the field as default value when the input value is empty
|
||||
if fd.hasDefault && inputValue[0] == "" {
|
||||
structField.Set(fd.defaultValue)
|
||||
continue
|
||||
}
|
||||
if _, isTime := structField.Interface().(time.Time); isTime {
|
||||
if err := setTimeField(inputValue[0], typeField, structField); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error {
|
||||
switch valueKind {
|
||||
case reflect.Int:
|
||||
return setIntField(val[0], 0, structField)
|
||||
case reflect.Int8:
|
||||
return setIntField(val[0], 8, structField)
|
||||
case reflect.Int16:
|
||||
return setIntField(val[0], 16, structField)
|
||||
case reflect.Int32:
|
||||
return setIntField(val[0], 32, structField)
|
||||
case reflect.Int64:
|
||||
return setIntField(val[0], 64, structField)
|
||||
case reflect.Uint:
|
||||
return setUintField(val[0], 0, structField)
|
||||
case reflect.Uint8:
|
||||
return setUintField(val[0], 8, structField)
|
||||
case reflect.Uint16:
|
||||
return setUintField(val[0], 16, structField)
|
||||
case reflect.Uint32:
|
||||
return setUintField(val[0], 32, structField)
|
||||
case reflect.Uint64:
|
||||
return setUintField(val[0], 64, structField)
|
||||
case reflect.Bool:
|
||||
return setBoolField(val[0], structField)
|
||||
case reflect.Float32:
|
||||
return setFloatField(val[0], 32, structField)
|
||||
case reflect.Float64:
|
||||
return setFloatField(val[0], 64, structField)
|
||||
case reflect.String:
|
||||
structField.SetString(val[0])
|
||||
case reflect.Slice:
|
||||
if option.Contains("split") {
|
||||
val = strings.Split(val[0], ",")
|
||||
}
|
||||
filtered := filterEmpty(val)
|
||||
switch structField.Type().Elem().Kind() {
|
||||
case reflect.Int64:
|
||||
valSli := make([]int64, 0, len(filtered))
|
||||
for i := 0; i < len(filtered); i++ {
|
||||
d, err := strconv.ParseInt(filtered[i], 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
valSli = append(valSli, d)
|
||||
}
|
||||
structField.Set(reflect.ValueOf(valSli))
|
||||
case reflect.String:
|
||||
valSli := make([]string, 0, len(filtered))
|
||||
for i := 0; i < len(filtered); i++ {
|
||||
valSli = append(valSli, filtered[i])
|
||||
}
|
||||
structField.Set(reflect.ValueOf(valSli))
|
||||
default:
|
||||
sliceOf := structField.Type().Elem().Kind()
|
||||
numElems := len(filtered)
|
||||
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered))
|
||||
for i := 0; i < numElems; i++ {
|
||||
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
structField.Set(slice)
|
||||
}
|
||||
default:
|
||||
return errors.New("Unknown type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIntField(val string, bitSize int, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "0"
|
||||
}
|
||||
intVal, err := strconv.ParseInt(val, 10, bitSize)
|
||||
if err == nil {
|
||||
field.SetInt(intVal)
|
||||
}
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func setUintField(val string, bitSize int, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "0"
|
||||
}
|
||||
uintVal, err := strconv.ParseUint(val, 10, bitSize)
|
||||
if err == nil {
|
||||
field.SetUint(uintVal)
|
||||
}
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func setBoolField(val string, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "false"
|
||||
}
|
||||
boolVal, err := strconv.ParseBool(val)
|
||||
if err == nil {
|
||||
field.SetBool(boolVal)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setFloatField(val string, bitSize int, field reflect.Value) error {
|
||||
if val == "" {
|
||||
val = "0.0"
|
||||
}
|
||||
floatVal, err := strconv.ParseFloat(val, bitSize)
|
||||
if err == nil {
|
||||
field.SetFloat(floatVal)
|
||||
}
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
|
||||
timeFormat := structField.Tag.Get("time_format")
|
||||
if timeFormat == "" {
|
||||
return errors.New("Blank time format")
|
||||
}
|
||||
|
||||
if val == "" {
|
||||
value.Set(reflect.ValueOf(time.Time{}))
|
||||
return nil
|
||||
}
|
||||
|
||||
l := time.Local
|
||||
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
|
||||
l = time.UTC
|
||||
}
|
||||
|
||||
if locTag := structField.Tag.Get("time_location"); locTag != "" {
|
||||
loc, err := time.LoadLocation(locTag)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
l = loc
|
||||
}
|
||||
|
||||
t, err := time.ParseInLocation(timeFormat, val, l)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
value.Set(reflect.ValueOf(t))
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterEmpty(val []string) []string {
|
||||
filtered := make([]string, 0, len(val))
|
||||
for _, v := range val {
|
||||
if v != "" {
|
||||
filtered = append(filtered, v)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
22
pkg/net/http/blademaster/binding/json.go
Normal file
22
pkg/net/http/blademaster/binding/json.go
Normal file
@ -0,0 +1,22 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type jsonBinding struct{}
|
||||
|
||||
func (jsonBinding) Name() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
|
||||
decoder := json.NewDecoder(req.Body)
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
19
pkg/net/http/blademaster/binding/query.go
Normal file
19
pkg/net/http/blademaster/binding/query.go
Normal file
@ -0,0 +1,19 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type queryBinding struct{}
|
||||
|
||||
func (queryBinding) Name() string {
|
||||
return "query"
|
||||
}
|
||||
|
||||
func (queryBinding) Bind(req *http.Request, obj interface{}) error {
|
||||
values := req.URL.Query()
|
||||
if err := mapForm(obj, values); err != nil {
|
||||
return err
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
44
pkg/net/http/blademaster/binding/tags.go
Normal file
44
pkg/net/http/blademaster/binding/tags.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package binding
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// tagOptions is the string following a comma in a struct field's "json"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string
|
||||
|
||||
// parseTag splits a struct field's json tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
return tag[:idx], tagOptions(tag[idx+1:])
|
||||
}
|
||||
return tag, tagOptions("")
|
||||
}
|
||||
|
||||
// Contains reports whether a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (o tagOptions) Contains(optionName string) bool {
|
||||
if len(o) == 0 {
|
||||
return false
|
||||
}
|
||||
s := string(o)
|
||||
for s != "" {
|
||||
var next string
|
||||
i := strings.Index(s, ",")
|
||||
if i >= 0 {
|
||||
s, next = s[:i], s[i+1:]
|
||||
}
|
||||
if s == optionName {
|
||||
return true
|
||||
}
|
||||
s = next
|
||||
}
|
||||
return false
|
||||
}
|
209
pkg/net/http/blademaster/binding/validate_test.go
Normal file
209
pkg/net/http/blademaster/binding/validate_test.go
Normal file
@ -0,0 +1,209 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testInterface interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type substructNoValidation struct {
|
||||
IString string
|
||||
IInt int
|
||||
}
|
||||
|
||||
type mapNoValidationSub map[string]substructNoValidation
|
||||
|
||||
type structNoValidationValues struct {
|
||||
substructNoValidation
|
||||
|
||||
Boolean bool
|
||||
|
||||
Uinteger uint
|
||||
Integer int
|
||||
Integer8 int8
|
||||
Integer16 int16
|
||||
Integer32 int32
|
||||
Integer64 int64
|
||||
Uinteger8 uint8
|
||||
Uinteger16 uint16
|
||||
Uinteger32 uint32
|
||||
Uinteger64 uint64
|
||||
|
||||
Float32 float32
|
||||
Float64 float64
|
||||
|
||||
String string
|
||||
|
||||
Date time.Time
|
||||
|
||||
Struct substructNoValidation
|
||||
InlinedStruct struct {
|
||||
String []string
|
||||
Integer int
|
||||
}
|
||||
|
||||
IntSlice []int
|
||||
IntPointerSlice []*int
|
||||
StructPointerSlice []*substructNoValidation
|
||||
StructSlice []substructNoValidation
|
||||
InterfaceSlice []testInterface
|
||||
|
||||
UniversalInterface interface{}
|
||||
CustomInterface testInterface
|
||||
|
||||
FloatMap map[string]float32
|
||||
StructMap mapNoValidationSub
|
||||
}
|
||||
|
||||
func createNoValidationValues() structNoValidationValues {
|
||||
integer := 1
|
||||
s := structNoValidationValues{
|
||||
Boolean: true,
|
||||
Uinteger: 1 << 29,
|
||||
Integer: -10000,
|
||||
Integer8: 120,
|
||||
Integer16: -20000,
|
||||
Integer32: 1 << 29,
|
||||
Integer64: 1 << 61,
|
||||
Uinteger8: 250,
|
||||
Uinteger16: 50000,
|
||||
Uinteger32: 1 << 31,
|
||||
Uinteger64: 1 << 62,
|
||||
Float32: 123.456,
|
||||
Float64: 123.456789,
|
||||
String: "text",
|
||||
Date: time.Time{},
|
||||
CustomInterface: &bytes.Buffer{},
|
||||
Struct: substructNoValidation{},
|
||||
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3},
|
||||
IntPointerSlice: []*int{&integer},
|
||||
StructSlice: []substructNoValidation{},
|
||||
UniversalInterface: 1.2,
|
||||
FloatMap: map[string]float32{
|
||||
"foo": 1.23,
|
||||
"bar": 232.323,
|
||||
},
|
||||
StructMap: mapNoValidationSub{
|
||||
"foo": substructNoValidation{},
|
||||
"bar": substructNoValidation{},
|
||||
},
|
||||
// StructPointerSlice []noValidationSub
|
||||
// InterfaceSlice []testInterface
|
||||
}
|
||||
s.InlinedStruct.Integer = 1000
|
||||
s.InlinedStruct.String = []string{"first", "second"}
|
||||
s.IString = "substring"
|
||||
s.IInt = 987654
|
||||
return s
|
||||
}
|
||||
|
||||
func TestValidateNoValidationValues(t *testing.T) {
|
||||
origin := createNoValidationValues()
|
||||
test := createNoValidationValues()
|
||||
empty := structNoValidationValues{}
|
||||
|
||||
assert.Nil(t, validate(test))
|
||||
assert.Nil(t, validate(&test))
|
||||
assert.Nil(t, validate(empty))
|
||||
assert.Nil(t, validate(&empty))
|
||||
|
||||
assert.Equal(t, origin, test)
|
||||
}
|
||||
|
||||
type structNoValidationPointer struct {
|
||||
// substructNoValidation
|
||||
|
||||
Boolean bool
|
||||
|
||||
Uinteger *uint
|
||||
Integer *int
|
||||
Integer8 *int8
|
||||
Integer16 *int16
|
||||
Integer32 *int32
|
||||
Integer64 *int64
|
||||
Uinteger8 *uint8
|
||||
Uinteger16 *uint16
|
||||
Uinteger32 *uint32
|
||||
Uinteger64 *uint64
|
||||
|
||||
Float32 *float32
|
||||
Float64 *float64
|
||||
|
||||
String *string
|
||||
|
||||
Date *time.Time
|
||||
|
||||
Struct *substructNoValidation
|
||||
|
||||
IntSlice *[]int
|
||||
IntPointerSlice *[]*int
|
||||
StructPointerSlice *[]*substructNoValidation
|
||||
StructSlice *[]substructNoValidation
|
||||
InterfaceSlice *[]testInterface
|
||||
|
||||
FloatMap *map[string]float32
|
||||
StructMap *mapNoValidationSub
|
||||
}
|
||||
|
||||
func TestValidateNoValidationPointers(t *testing.T) {
|
||||
//origin := createNoValidation_values()
|
||||
//test := createNoValidation_values()
|
||||
empty := structNoValidationPointer{}
|
||||
|
||||
//assert.Nil(t, validate(test))
|
||||
//assert.Nil(t, validate(&test))
|
||||
assert.Nil(t, validate(empty))
|
||||
assert.Nil(t, validate(&empty))
|
||||
|
||||
//assert.Equal(t, origin, test)
|
||||
}
|
||||
|
||||
type Object map[string]interface{}
|
||||
|
||||
func TestValidatePrimitives(t *testing.T) {
|
||||
obj := Object{"foo": "bar", "bar": 1}
|
||||
assert.NoError(t, validate(obj))
|
||||
assert.NoError(t, validate(&obj))
|
||||
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1})
|
||||
|
||||
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}}
|
||||
assert.NoError(t, validate(obj2))
|
||||
assert.NoError(t, validate(&obj2))
|
||||
|
||||
nu := 10
|
||||
assert.NoError(t, validate(nu))
|
||||
assert.NoError(t, validate(&nu))
|
||||
assert.Equal(t, nu, 10)
|
||||
|
||||
str := "value"
|
||||
assert.NoError(t, validate(str))
|
||||
assert.NoError(t, validate(&str))
|
||||
assert.Equal(t, str, "value")
|
||||
}
|
||||
|
||||
// structCustomValidation is a helper struct we use to check that
|
||||
// custom validation can be registered on it.
|
||||
// The `notone` binding directive is for custom validation and registered later.
|
||||
// type structCustomValidation struct {
|
||||
// Integer int `binding:"notone"`
|
||||
// }
|
||||
|
||||
// notOne is a custom validator meant to be used with `validator.v8` library.
|
||||
// The method signature for `v9` is significantly different and this function
|
||||
// would need to be changed for tests to pass after upgrade.
|
||||
// See https://github.com/gin-gonic/gin/pull/1015.
|
||||
// func notOne(
|
||||
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
|
||||
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
|
||||
// ) bool {
|
||||
// if val, ok := field.Interface().(int); ok {
|
||||
// return val != 1
|
||||
// }
|
||||
// return false
|
||||
// }
|
22
pkg/net/http/blademaster/binding/xml.go
Normal file
22
pkg/net/http/blademaster/binding/xml.go
Normal file
@ -0,0 +1,22 @@
|
||||
package binding
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type xmlBinding struct{}
|
||||
|
||||
func (xmlBinding) Name() string {
|
||||
return "xml"
|
||||
}
|
||||
|
||||
func (xmlBinding) Bind(req *http.Request, obj interface{}) error {
|
||||
decoder := xml.NewDecoder(req.Body)
|
||||
if err := decoder.Decode(obj); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
return validate(obj)
|
||||
}
|
306
pkg/net/http/blademaster/context.go
Normal file
306
pkg/net/http/blademaster/context.go
Normal file
@ -0,0 +1,306 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/ecode"
|
||||
"github.com/bilibili/Kratos/pkg/net/http/blademaster/binding"
|
||||
"github.com/bilibili/Kratos/pkg/net/http/blademaster/render"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/gogo/protobuf/types"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
_abortIndex int8 = math.MaxInt8 / 2
|
||||
)
|
||||
|
||||
var (
|
||||
_openParen = []byte("(")
|
||||
_closeParen = []byte(")")
|
||||
)
|
||||
|
||||
// Context is the most important part. It allows us to pass variables between
|
||||
// middleware, manage the flow, validate the JSON of a request and render a
|
||||
// JSON response for example.
|
||||
type Context struct {
|
||||
context.Context
|
||||
|
||||
Request *http.Request
|
||||
Writer http.ResponseWriter
|
||||
|
||||
// flow control
|
||||
index int8
|
||||
handlers []HandlerFunc
|
||||
|
||||
// Keys is a key/value pair exclusively for the context of each request.
|
||||
Keys map[string]interface{}
|
||||
|
||||
Error error
|
||||
|
||||
method string
|
||||
engine *Engine
|
||||
}
|
||||
|
||||
/************************************/
|
||||
/*********** FLOW CONTROL ***********/
|
||||
/************************************/
|
||||
|
||||
// Next should be used only inside middleware.
|
||||
// It executes the pending handlers in the chain inside the calling handler.
|
||||
// See example in godoc.
|
||||
func (c *Context) Next() {
|
||||
c.index++
|
||||
s := int8(len(c.handlers))
|
||||
for ; c.index < s; c.index++ {
|
||||
// only check method on last handler, otherwise middlewares
|
||||
// will never be effected if request method is not matched
|
||||
if c.index == s-1 && c.method != c.Request.Method {
|
||||
code := http.StatusMethodNotAllowed
|
||||
c.Error = ecode.MethodNotAllowed
|
||||
http.Error(c.Writer, http.StatusText(code), code)
|
||||
return
|
||||
}
|
||||
|
||||
c.handlers[c.index](c)
|
||||
}
|
||||
}
|
||||
|
||||
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
|
||||
// Let's say you have an authorization middleware that validates that the current request is authorized.
|
||||
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
|
||||
// for this request are not called.
|
||||
func (c *Context) Abort() {
|
||||
c.index = _abortIndex
|
||||
}
|
||||
|
||||
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
|
||||
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
|
||||
func (c *Context) AbortWithStatus(code int) {
|
||||
c.Status(code)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// IsAborted returns true if the current context was aborted.
|
||||
func (c *Context) IsAborted() bool {
|
||||
return c.index >= _abortIndex
|
||||
}
|
||||
|
||||
/************************************/
|
||||
/******** METADATA MANAGEMENT********/
|
||||
/************************************/
|
||||
|
||||
// Set is used to store a new key/value pair exclusively for this context.
|
||||
// It also lazy initializes c.Keys if it was not used previously.
|
||||
func (c *Context) Set(key string, value interface{}) {
|
||||
if c.Keys == nil {
|
||||
c.Keys = make(map[string]interface{})
|
||||
}
|
||||
c.Keys[key] = value
|
||||
}
|
||||
|
||||
// Get returns the value for the given key, ie: (value, true).
|
||||
// If the value does not exists it returns (nil, false)
|
||||
func (c *Context) Get(key string) (value interface{}, exists bool) {
|
||||
value, exists = c.Keys[key]
|
||||
return
|
||||
}
|
||||
|
||||
/************************************/
|
||||
/******** RESPONSE RENDERING ********/
|
||||
/************************************/
|
||||
|
||||
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
|
||||
func bodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Status sets the HTTP response code.
|
||||
func (c *Context) Status(code int) {
|
||||
c.Writer.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Render http response with http code by a render instance.
|
||||
func (c *Context) Render(code int, r render.Render) {
|
||||
r.WriteContentType(c.Writer)
|
||||
if code > 0 {
|
||||
c.Status(code)
|
||||
}
|
||||
|
||||
if !bodyAllowedForStatus(code) {
|
||||
return
|
||||
}
|
||||
|
||||
params := c.Request.Form
|
||||
|
||||
cb := params.Get("callback")
|
||||
jsonp := cb != "" && params.Get("jsonp") == "jsonp"
|
||||
if jsonp {
|
||||
c.Writer.Write([]byte(cb))
|
||||
c.Writer.Write(_openParen)
|
||||
}
|
||||
|
||||
if err := r.Render(c.Writer); err != nil {
|
||||
c.Error = err
|
||||
return
|
||||
}
|
||||
|
||||
if jsonp {
|
||||
if _, err := c.Writer.Write(_closeParen); err != nil {
|
||||
c.Error = errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JSON serializes the given struct as JSON into the response body.
|
||||
// It also sets the Content-Type as "application/json".
|
||||
func (c *Context) JSON(data interface{}, err error) {
|
||||
code := http.StatusOK
|
||||
c.Error = err
|
||||
bcode := ecode.Cause(err)
|
||||
// TODO app allow 5xx?
|
||||
/*
|
||||
if bcode.Code() == -500 {
|
||||
code = http.StatusServiceUnavailable
|
||||
}
|
||||
*/
|
||||
writeStatusCode(c.Writer, bcode.Code())
|
||||
c.Render(code, render.JSON{
|
||||
Code: bcode.Code(),
|
||||
Message: bcode.Message(),
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// JSONMap serializes the given map as map JSON into the response body.
|
||||
// It also sets the Content-Type as "application/json".
|
||||
func (c *Context) JSONMap(data map[string]interface{}, err error) {
|
||||
code := http.StatusOK
|
||||
c.Error = err
|
||||
bcode := ecode.Cause(err)
|
||||
// TODO app allow 5xx?
|
||||
/*
|
||||
if bcode.Code() == -500 {
|
||||
code = http.StatusServiceUnavailable
|
||||
}
|
||||
*/
|
||||
writeStatusCode(c.Writer, bcode.Code())
|
||||
data["code"] = bcode.Code()
|
||||
if _, ok := data["message"]; !ok {
|
||||
data["message"] = bcode.Message()
|
||||
}
|
||||
c.Render(code, render.MapJSON(data))
|
||||
}
|
||||
|
||||
// XML serializes the given struct as XML into the response body.
|
||||
// It also sets the Content-Type as "application/xml".
|
||||
func (c *Context) XML(data interface{}, err error) {
|
||||
code := http.StatusOK
|
||||
c.Error = err
|
||||
bcode := ecode.Cause(err)
|
||||
// TODO app allow 5xx?
|
||||
/*
|
||||
if bcode.Code() == -500 {
|
||||
code = http.StatusServiceUnavailable
|
||||
}
|
||||
*/
|
||||
writeStatusCode(c.Writer, bcode.Code())
|
||||
c.Render(code, render.XML{
|
||||
Code: bcode.Code(),
|
||||
Message: bcode.Message(),
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Protobuf serializes the given struct as PB into the response body.
|
||||
// It also sets the ContentType as "application/x-protobuf".
|
||||
func (c *Context) Protobuf(data proto.Message, err error) {
|
||||
var (
|
||||
bytes []byte
|
||||
)
|
||||
|
||||
code := http.StatusOK
|
||||
c.Error = err
|
||||
bcode := ecode.Cause(err)
|
||||
|
||||
any := new(types.Any)
|
||||
if data != nil {
|
||||
if bytes, err = proto.Marshal(data); err != nil {
|
||||
c.Error = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data)
|
||||
any.Value = bytes
|
||||
}
|
||||
writeStatusCode(c.Writer, bcode.Code())
|
||||
c.Render(code, render.PB{
|
||||
Code: int64(bcode.Code()),
|
||||
Message: bcode.Message(),
|
||||
Data: any,
|
||||
})
|
||||
}
|
||||
|
||||
// Bytes writes some data into the body stream and updates the HTTP code.
|
||||
func (c *Context) Bytes(code int, contentType string, data ...[]byte) {
|
||||
c.Render(code, render.Data{
|
||||
ContentType: contentType,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// String writes the given string into the response body.
|
||||
func (c *Context) String(code int, format string, values ...interface{}) {
|
||||
c.Render(code, render.String{Format: format, Data: values})
|
||||
}
|
||||
|
||||
// Redirect returns a HTTP redirect to the specific location.
|
||||
func (c *Context) Redirect(code int, location string) {
|
||||
c.Render(-1, render.Redirect{
|
||||
Code: code,
|
||||
Location: location,
|
||||
Request: c.Request,
|
||||
})
|
||||
}
|
||||
|
||||
// BindWith bind req arg with parser.
|
||||
func (c *Context) BindWith(obj interface{}, b binding.Binding) error {
|
||||
return c.mustBindWith(obj, b)
|
||||
}
|
||||
|
||||
// Bind bind req arg with defult form binding.
|
||||
func (c *Context) Bind(obj interface{}) error {
|
||||
return c.mustBindWith(obj, binding.Form)
|
||||
}
|
||||
|
||||
// mustBindWith binds the passed struct pointer using the specified binding engine.
|
||||
// It will abort the request with HTTP 400 if any error ocurrs.
|
||||
// See the binding package.
|
||||
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) {
|
||||
if err = b.Bind(c.Request, obj); err != nil {
|
||||
c.Error = ecode.RequestErr
|
||||
c.Render(http.StatusOK, render.JSON{
|
||||
Code: ecode.RequestErr.Code(),
|
||||
Message: err.Error(),
|
||||
Data: nil,
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func writeStatusCode(w http.ResponseWriter, ecode int) {
|
||||
header := w.Header()
|
||||
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
|
||||
}
|
249
pkg/net/http/blademaster/cors.go
Normal file
249
pkg/net/http/blademaster/cors.go
Normal file
@ -0,0 +1,249 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/log"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// CORSConfig represents all available options for the middleware.
|
||||
type CORSConfig struct {
|
||||
AllowAllOrigins bool
|
||||
|
||||
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
|
||||
// If the special "*" value is present in the list, all origins will be allowed.
|
||||
// Default value is []
|
||||
AllowOrigins []string
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It take the origin
|
||||
// as argument and returns true if allowed or false otherwise. If this option is
|
||||
// set, the content of AllowedOrigins is ignored.
|
||||
AllowOriginFunc func(origin string) bool
|
||||
|
||||
// AllowedMethods is a list of methods the client is allowed to use with
|
||||
// cross-domain requests. Default value is simple methods (GET and POST)
|
||||
AllowMethods []string
|
||||
|
||||
// AllowedHeaders is list of non simple headers the client is allowed to use with
|
||||
// cross-domain requests.
|
||||
AllowHeaders []string
|
||||
|
||||
// AllowCredentials indicates whether the request can include user credentials like
|
||||
// cookies, HTTP authentication or client side SSL certificates.
|
||||
AllowCredentials bool
|
||||
|
||||
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
|
||||
// API specification
|
||||
ExposeHeaders []string
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached
|
||||
MaxAge time.Duration
|
||||
}
|
||||
|
||||
type cors struct {
|
||||
allowAllOrigins bool
|
||||
allowCredentials bool
|
||||
allowOriginFunc func(string) bool
|
||||
allowOrigins []string
|
||||
normalHeaders http.Header
|
||||
preflightHeaders http.Header
|
||||
}
|
||||
|
||||
type converter func(string) string
|
||||
|
||||
// Validate is check configuration of user defined.
|
||||
func (c *CORSConfig) Validate() error {
|
||||
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
|
||||
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
|
||||
}
|
||||
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
|
||||
return errors.New("conflict settings: all origins disabled")
|
||||
}
|
||||
for _, origin := range c.AllowOrigins {
|
||||
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
|
||||
return errors.New("bad origin: origins must either be '*' or include http:// or https://")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CORS returns the location middleware with default configuration.
|
||||
func CORS(allowOriginHosts []string) HandlerFunc {
|
||||
config := &CORSConfig{
|
||||
AllowMethods: []string{"GET", "POST"},
|
||||
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: time.Duration(0),
|
||||
AllowOriginFunc: func(origin string) bool {
|
||||
for _, host := range allowOriginHosts {
|
||||
if strings.HasSuffix(strings.ToLower(origin), host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
return newCORS(config)
|
||||
}
|
||||
|
||||
// newCORS returns the location middleware with user-defined custom configuration.
|
||||
func newCORS(config *CORSConfig) HandlerFunc {
|
||||
if err := config.Validate(); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
cors := &cors{
|
||||
allowOriginFunc: config.AllowOriginFunc,
|
||||
allowAllOrigins: config.AllowAllOrigins,
|
||||
allowCredentials: config.AllowCredentials,
|
||||
allowOrigins: normalize(config.AllowOrigins),
|
||||
normalHeaders: generateNormalHeaders(config),
|
||||
preflightHeaders: generatePreflightHeaders(config),
|
||||
}
|
||||
|
||||
return func(c *Context) {
|
||||
cors.applyCORS(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (cors *cors) applyCORS(c *Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if len(origin) == 0 {
|
||||
// request is not a CORS request
|
||||
return
|
||||
}
|
||||
if !cors.validateOrigin(origin) {
|
||||
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin)
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
cors.handlePreflight(c)
|
||||
defer c.AbortWithStatus(200)
|
||||
} else {
|
||||
cors.handleNormal(c)
|
||||
}
|
||||
|
||||
if !cors.allowAllOrigins {
|
||||
header := c.Writer.Header()
|
||||
header.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
}
|
||||
|
||||
func (cors *cors) validateOrigin(origin string) bool {
|
||||
if cors.allowAllOrigins {
|
||||
return true
|
||||
}
|
||||
for _, value := range cors.allowOrigins {
|
||||
if value == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if cors.allowOriginFunc != nil {
|
||||
return cors.allowOriginFunc(origin)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (cors *cors) handlePreflight(c *Context) {
|
||||
header := c.Writer.Header()
|
||||
for key, value := range cors.preflightHeaders {
|
||||
header[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func (cors *cors) handleNormal(c *Context) {
|
||||
header := c.Writer.Header()
|
||||
for key, value := range cors.normalHeaders {
|
||||
header[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func generateNormalHeaders(c *CORSConfig) http.Header {
|
||||
headers := make(http.Header)
|
||||
if c.AllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// backport support for early browsers
|
||||
if len(c.AllowMethods) > 0 {
|
||||
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
|
||||
value := strings.Join(allowMethods, ",")
|
||||
headers.Set("Access-Control-Allow-Methods", value)
|
||||
}
|
||||
|
||||
if len(c.ExposeHeaders) > 0 {
|
||||
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
|
||||
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
|
||||
}
|
||||
if c.AllowAllOrigins {
|
||||
headers.Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
headers.Set("Vary", "Origin")
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func generatePreflightHeaders(c *CORSConfig) http.Header {
|
||||
headers := make(http.Header)
|
||||
if c.AllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if len(c.AllowMethods) > 0 {
|
||||
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
|
||||
value := strings.Join(allowMethods, ",")
|
||||
headers.Set("Access-Control-Allow-Methods", value)
|
||||
}
|
||||
if len(c.AllowHeaders) > 0 {
|
||||
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
|
||||
value := strings.Join(allowHeaders, ",")
|
||||
headers.Set("Access-Control-Allow-Headers", value)
|
||||
}
|
||||
if c.MaxAge > time.Duration(0) {
|
||||
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
|
||||
headers.Set("Access-Control-Max-Age", value)
|
||||
}
|
||||
if c.AllowAllOrigins {
|
||||
headers.Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
|
||||
headers.Add("Vary", "Origin")
|
||||
headers.Add("Vary", "Access-Control-Request-Method")
|
||||
headers.Add("Vary", "Access-Control-Request-Headers")
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func normalize(values []string) []string {
|
||||
if values == nil {
|
||||
return nil
|
||||
}
|
||||
distinctMap := make(map[string]bool, len(values))
|
||||
normalized := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
value = strings.TrimSpace(value)
|
||||
value = strings.ToLower(value)
|
||||
if _, seen := distinctMap[value]; !seen {
|
||||
normalized = append(normalized, value)
|
||||
distinctMap[value] = true
|
||||
}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func convert(s []string, c converter) []string {
|
||||
var out []string
|
||||
for _, i := range s {
|
||||
out = append(out, c(i))
|
||||
}
|
||||
return out
|
||||
}
|
69
pkg/net/http/blademaster/csrf.go
Normal file
69
pkg/net/http/blademaster/csrf.go
Normal file
@ -0,0 +1,69 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/log"
|
||||
)
|
||||
|
||||
func matchHostSuffix(suffix string) func(*url.URL) bool {
|
||||
return func(uri *url.URL) bool {
|
||||
return strings.HasSuffix(strings.ToLower(uri.Host), suffix)
|
||||
}
|
||||
}
|
||||
|
||||
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool {
|
||||
return func(uri *url.URL) bool {
|
||||
return pattern.MatchString(strings.ToLower(uri.String()))
|
||||
}
|
||||
}
|
||||
|
||||
// CSRF returns the csrf middleware to prevent invalid cross site request.
|
||||
// Only referer is checked currently.
|
||||
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc {
|
||||
validations := []func(*url.URL) bool{}
|
||||
|
||||
addHostSuffix := func(suffix string) {
|
||||
validations = append(validations, matchHostSuffix(suffix))
|
||||
}
|
||||
addPattern := func(pattern string) {
|
||||
validations = append(validations, matchPattern(regexp.MustCompile(pattern)))
|
||||
}
|
||||
|
||||
for _, r := range allowHosts {
|
||||
addHostSuffix(r)
|
||||
}
|
||||
for _, p := range allowPattern {
|
||||
addPattern(p)
|
||||
}
|
||||
|
||||
return func(c *Context) {
|
||||
referer := c.Request.Header.Get("Referer")
|
||||
params := c.Request.Form
|
||||
cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "")
|
||||
if referer == "" {
|
||||
if !cross {
|
||||
return
|
||||
}
|
||||
log.V(5).Info("The request's Referer header is empty.")
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
illegal := true
|
||||
if uri, err := url.Parse(referer); err == nil && uri.Host != "" {
|
||||
for _, validate := range validations {
|
||||
if validate(uri) {
|
||||
illegal = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if illegal {
|
||||
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer)
|
||||
c.AbortWithStatus(403)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
69
pkg/net/http/blademaster/logger.go
Normal file
69
pkg/net/http/blademaster/logger.go
Normal file
@ -0,0 +1,69 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/ecode"
|
||||
"github.com/bilibili/Kratos/pkg/log"
|
||||
"github.com/bilibili/Kratos/pkg/net/metadata"
|
||||
)
|
||||
|
||||
// Logger is logger middleware
|
||||
func Logger() HandlerFunc {
|
||||
const noUser = "no_user"
|
||||
return func(c *Context) {
|
||||
now := time.Now()
|
||||
ip := metadata.String(c, metadata.RemoteIP)
|
||||
req := c.Request
|
||||
path := req.URL.Path
|
||||
params := req.Form
|
||||
var quota float64
|
||||
if deadline, ok := c.Context.Deadline(); ok {
|
||||
quota = time.Until(deadline).Seconds()
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
err := c.Error
|
||||
cerr := ecode.Cause(err)
|
||||
dt := time.Since(now)
|
||||
caller := metadata.String(c, metadata.Caller)
|
||||
if caller == "" {
|
||||
caller = noUser
|
||||
}
|
||||
|
||||
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10))
|
||||
stats.Timing(caller, int64(dt/time.Millisecond), path[1:])
|
||||
|
||||
lf := log.Infov
|
||||
errmsg := ""
|
||||
isSlow := dt >= (time.Millisecond * 500)
|
||||
if err != nil {
|
||||
errmsg = err.Error()
|
||||
lf = log.Errorv
|
||||
if cerr.Code() > 0 {
|
||||
lf = log.Warnv
|
||||
}
|
||||
} else {
|
||||
if isSlow {
|
||||
lf = log.Warnv
|
||||
}
|
||||
}
|
||||
lf(c,
|
||||
log.KVString("method", req.Method),
|
||||
log.KVString("ip", ip),
|
||||
log.KVString("user", caller),
|
||||
log.KVString("path", path),
|
||||
log.KVString("params", params.Encode()),
|
||||
log.KVInt("ret", cerr.Code()),
|
||||
log.KVString("msg", cerr.Message()),
|
||||
log.KVString("stack", fmt.Sprintf("%+v", err)),
|
||||
log.KVString("err", errmsg),
|
||||
log.KVFloat64("timeout_quota", quota),
|
||||
log.KVFloat64("ts", dt.Seconds()),
|
||||
log.KVString("source", "http-access-log"),
|
||||
)
|
||||
}
|
||||
}
|
@ -17,13 +17,14 @@ const (
|
||||
_httpHeaderUser = "x1-bmspy-user"
|
||||
_httpHeaderColor = "x1-bmspy-color"
|
||||
_httpHeaderTimeout = "x1-bmspy-timeout"
|
||||
_httpHeaderMirror = "x1-bmspy-mirror"
|
||||
_httpHeaderRemoteIP = "x-backend-bm-real-ip"
|
||||
_httpHeaderRemoteIPPort = "x-backend-bm-real-ipport"
|
||||
)
|
||||
|
||||
// mirror return true if x1-bilispy-mirror in http header and its value is 1 or true.
|
||||
// mirror return true if x-bmspy-mirror in http header and its value is 1 or true.
|
||||
func mirror(req *http.Request) bool {
|
||||
mirrorStr := req.Header.Get("x1-bilispy-mirror")
|
||||
mirrorStr := req.Header.Get(_httpHeaderMirror)
|
||||
if mirrorStr == "" {
|
||||
return false
|
||||
}
|
||||
@ -79,7 +80,7 @@ func timeout(req *http.Request) time.Duration {
|
||||
}
|
||||
|
||||
// remoteIP implements a best effort algorithm to return the real client IP, it parses
|
||||
// X-BACKEND-BILI-REAL-IP or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
|
||||
// x-backend-bm-real-ip or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
|
||||
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
|
||||
func remoteIP(req *http.Request) (remote string) {
|
||||
if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" {
|
||||
|
46
pkg/net/http/blademaster/perf.go
Normal file
46
pkg/net/http/blademaster/perf.go
Normal file
@ -0,0 +1,46 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/conf/dsn"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
_perfOnce sync.Once
|
||||
_perfDSN string
|
||||
)
|
||||
|
||||
func init() {
|
||||
v := os.Getenv("HTTP_PERF")
|
||||
if v == "" {
|
||||
v = "tcp://0.0.0.0:2333"
|
||||
}
|
||||
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.")
|
||||
}
|
||||
|
||||
func startPerf() {
|
||||
_perfOnce.Do(func() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
|
||||
go func() {
|
||||
d, err := dsn.Parse(_perfDSN)
|
||||
if err != nil {
|
||||
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err))
|
||||
}
|
||||
if err := http.ListenAndServe(d.Host, mux); err != nil {
|
||||
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err))
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
12
pkg/net/http/blademaster/prometheus.go
Normal file
12
pkg/net/http/blademaster/prometheus.go
Normal file
@ -0,0 +1,12 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
func monitor() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
h := promhttp.Handler()
|
||||
h.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
32
pkg/net/http/blademaster/recovery.go
Normal file
32
pkg/net/http/blademaster/recovery.go
Normal file
@ -0,0 +1,32 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httputil"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/log"
|
||||
)
|
||||
|
||||
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
|
||||
func Recovery() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
defer func() {
|
||||
var rawReq []byte
|
||||
if err := recover(); err != nil {
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
if c.Request != nil {
|
||||
rawReq, _ = httputil.DumpRequest(c.Request, false)
|
||||
}
|
||||
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf)
|
||||
fmt.Fprintf(os.Stderr, pl)
|
||||
log.Error(pl)
|
||||
c.AbortWithStatus(500)
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
30
pkg/net/http/blademaster/render/data.go
Normal file
30
pkg/net/http/blademaster/render/data.go
Normal file
@ -0,0 +1,30 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Data common bytes struct.
|
||||
type Data struct {
|
||||
ContentType string
|
||||
Data [][]byte
|
||||
}
|
||||
|
||||
// Render (Data) writes data with custom ContentType.
|
||||
func (r Data) Render(w http.ResponseWriter) (err error) {
|
||||
r.WriteContentType(w)
|
||||
for _, d := range r.Data {
|
||||
if _, err = w.Write(d); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// WriteContentType writes data with custom ContentType.
|
||||
func (r Data) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, []string{r.ContentType})
|
||||
}
|
58
pkg/net/http/blademaster/render/json.go
Normal file
58
pkg/net/http/blademaster/render/json.go
Normal file
@ -0,0 +1,58 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var jsonContentType = []string{"application/json; charset=utf-8"}
|
||||
|
||||
// JSON common json struct.
|
||||
type JSON struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
TTL int `json:"ttl"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) {
|
||||
var jsonBytes []byte
|
||||
writeContentType(w, jsonContentType)
|
||||
if jsonBytes, err = json.Marshal(obj); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
if _, err = w.Write(jsonBytes); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Render (JSON) writes data with json ContentType.
|
||||
func (r JSON) Render(w http.ResponseWriter) error {
|
||||
// FIXME(zhoujiahui): the TTL field will be configurable in the future
|
||||
if r.TTL <= 0 {
|
||||
r.TTL = 1
|
||||
}
|
||||
return writeJSON(w, r)
|
||||
}
|
||||
|
||||
// WriteContentType write json ContentType.
|
||||
func (r JSON) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, jsonContentType)
|
||||
}
|
||||
|
||||
// MapJSON common map json struct.
|
||||
type MapJSON map[string]interface{}
|
||||
|
||||
// Render (MapJSON) writes data with json ContentType.
|
||||
func (m MapJSON) Render(w http.ResponseWriter) error {
|
||||
return writeJSON(w, m)
|
||||
}
|
||||
|
||||
// WriteContentType write json ContentType.
|
||||
func (m MapJSON) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, jsonContentType)
|
||||
}
|
38
pkg/net/http/blademaster/render/protobuf.go
Normal file
38
pkg/net/http/blademaster/render/protobuf.go
Normal file
@ -0,0 +1,38 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var pbContentType = []string{"application/x-protobuf"}
|
||||
|
||||
// Render (PB) writes data with protobuf ContentType.
|
||||
func (r PB) Render(w http.ResponseWriter) error {
|
||||
if r.TTL <= 0 {
|
||||
r.TTL = 1
|
||||
}
|
||||
return writePB(w, r)
|
||||
}
|
||||
|
||||
// WriteContentType write protobuf ContentType.
|
||||
func (r PB) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, pbContentType)
|
||||
}
|
||||
|
||||
func writePB(w http.ResponseWriter, obj PB) (err error) {
|
||||
var pbBytes []byte
|
||||
writeContentType(w, pbContentType)
|
||||
|
||||
if pbBytes, err = proto.Marshal(&obj); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = w.Write(pbBytes); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
return
|
||||
}
|
26
pkg/net/http/blademaster/render/redirect.go
Normal file
26
pkg/net/http/blademaster/render/redirect.go
Normal file
@ -0,0 +1,26 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Redirect render for redirect to specified location.
|
||||
type Redirect struct {
|
||||
Code int
|
||||
Request *http.Request
|
||||
Location string
|
||||
}
|
||||
|
||||
// Render (Redirect) redirect to specified location.
|
||||
func (r Redirect) Render(w http.ResponseWriter) error {
|
||||
if (r.Code < 300 || r.Code > 308) && r.Code != 201 {
|
||||
return errors.Errorf("Cannot redirect with status code %d", r.Code)
|
||||
}
|
||||
http.Redirect(w, r.Request, r.Location, r.Code)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteContentType noneContentType.
|
||||
func (r Redirect) WriteContentType(http.ResponseWriter) {}
|
30
pkg/net/http/blademaster/render/render.go
Normal file
30
pkg/net/http/blademaster/render/render.go
Normal file
@ -0,0 +1,30 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Render http reponse render.
|
||||
type Render interface {
|
||||
// Render render it to http response writer.
|
||||
Render(http.ResponseWriter) error
|
||||
// WriteContentType write content-type to http response writer.
|
||||
WriteContentType(w http.ResponseWriter)
|
||||
}
|
||||
|
||||
var (
|
||||
_ Render = JSON{}
|
||||
_ Render = MapJSON{}
|
||||
_ Render = XML{}
|
||||
_ Render = String{}
|
||||
_ Render = Redirect{}
|
||||
_ Render = Data{}
|
||||
_ Render = PB{}
|
||||
)
|
||||
|
||||
func writeContentType(w http.ResponseWriter, value []string) {
|
||||
header := w.Header()
|
||||
if val := header["Content-Type"]; len(val) == 0 {
|
||||
header["Content-Type"] = value
|
||||
}
|
||||
}
|
89
pkg/net/http/blademaster/render/render.pb.go
Normal file
89
pkg/net/http/blademaster/render/render.pb.go
Normal file
@ -0,0 +1,89 @@
|
||||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||
// source: pb.proto
|
||||
|
||||
/*
|
||||
Package render is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
pb.proto
|
||||
|
||||
It has these top-level messages:
|
||||
PB
|
||||
*/
|
||||
package render
|
||||
|
||||
import proto "github.com/gogo/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import google_protobuf "github.com/gogo/protobuf/types"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type PB struct {
|
||||
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"`
|
||||
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"`
|
||||
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"`
|
||||
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"`
|
||||
}
|
||||
|
||||
func (m *PB) Reset() { *m = PB{} }
|
||||
func (m *PB) String() string { return proto.CompactTextString(m) }
|
||||
func (*PB) ProtoMessage() {}
|
||||
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} }
|
||||
|
||||
func (m *PB) GetCode() int64 {
|
||||
if m != nil {
|
||||
return m.Code
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *PB) GetMessage() string {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *PB) GetTTL() uint64 {
|
||||
if m != nil {
|
||||
return m.TTL
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *PB) GetData() *google_protobuf.Any {
|
||||
if m != nil {
|
||||
return m.Data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*PB)(nil), "render.PB")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) }
|
||||
|
||||
var fileDescriptorPb = []byte{
|
||||
// 154 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b,
|
||||
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf,
|
||||
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94,
|
||||
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30,
|
||||
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89,
|
||||
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48,
|
||||
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92,
|
||||
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66,
|
||||
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00,
|
||||
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00,
|
||||
}
|
14
pkg/net/http/blademaster/render/render.proto
Normal file
14
pkg/net/http/blademaster/render/render.proto
Normal file
@ -0,0 +1,14 @@
|
||||
// use under command to generate pb.pb.go
|
||||
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto
|
||||
syntax = "proto3";
|
||||
package render;
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
|
||||
|
||||
message PB {
|
||||
int64 Code = 1;
|
||||
string Message = 2;
|
||||
uint64 TTL = 3;
|
||||
google.protobuf.Any Data = 4;
|
||||
}
|
40
pkg/net/http/blademaster/render/string.go
Normal file
40
pkg/net/http/blademaster/render/string.go
Normal file
@ -0,0 +1,40 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var plainContentType = []string{"text/plain; charset=utf-8"}
|
||||
|
||||
// String common string struct.
|
||||
type String struct {
|
||||
Format string
|
||||
Data []interface{}
|
||||
}
|
||||
|
||||
// Render (String) writes data with custom ContentType.
|
||||
func (r String) Render(w http.ResponseWriter) error {
|
||||
return writeString(w, r.Format, r.Data)
|
||||
}
|
||||
|
||||
// WriteContentType writes string with text/plain ContentType.
|
||||
func (r String) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, plainContentType)
|
||||
}
|
||||
|
||||
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) {
|
||||
writeContentType(w, plainContentType)
|
||||
if len(data) > 0 {
|
||||
_, err = fmt.Fprintf(w, format, data...)
|
||||
} else {
|
||||
_, err = io.WriteString(w, format)
|
||||
}
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
return
|
||||
}
|
31
pkg/net/http/blademaster/render/xml.go
Normal file
31
pkg/net/http/blademaster/render/xml.go
Normal file
@ -0,0 +1,31 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// XML common xml struct.
|
||||
type XML struct {
|
||||
Code int
|
||||
Message string
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
var xmlContentType = []string{"application/xml; charset=utf-8"}
|
||||
|
||||
// Render (XML) writes data with xml ContentType.
|
||||
func (r XML) Render(w http.ResponseWriter) (err error) {
|
||||
r.WriteContentType(w)
|
||||
if err = xml.NewEncoder(w).Encode(r.Data); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// WriteContentType write xml ContentType.
|
||||
func (r XML) WriteContentType(w http.ResponseWriter) {
|
||||
writeContentType(w, xmlContentType)
|
||||
}
|
166
pkg/net/http/blademaster/routergroup.go
Normal file
166
pkg/net/http/blademaster/routergroup.go
Normal file
@ -0,0 +1,166 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// IRouter http router framework interface.
|
||||
type IRouter interface {
|
||||
IRoutes
|
||||
Group(string, ...HandlerFunc) *RouterGroup
|
||||
}
|
||||
|
||||
// IRoutes http router interface.
|
||||
type IRoutes interface {
|
||||
UseFunc(...HandlerFunc) IRoutes
|
||||
Use(...Handler) IRoutes
|
||||
|
||||
Handle(string, string, ...HandlerFunc) IRoutes
|
||||
HEAD(string, ...HandlerFunc) IRoutes
|
||||
GET(string, ...HandlerFunc) IRoutes
|
||||
POST(string, ...HandlerFunc) IRoutes
|
||||
PUT(string, ...HandlerFunc) IRoutes
|
||||
DELETE(string, ...HandlerFunc) IRoutes
|
||||
}
|
||||
|
||||
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
|
||||
// and an array of handlers (middleware).
|
||||
type RouterGroup struct {
|
||||
Handlers []HandlerFunc
|
||||
basePath string
|
||||
engine *Engine
|
||||
root bool
|
||||
baseConfig *MethodConfig
|
||||
}
|
||||
|
||||
var _ IRouter = &RouterGroup{}
|
||||
|
||||
// Use adds middleware to the group, see example code in doc.
|
||||
func (group *RouterGroup) Use(middleware ...Handler) IRoutes {
|
||||
for _, m := range middleware {
|
||||
group.Handlers = append(group.Handlers, m.ServeHTTP)
|
||||
}
|
||||
return group.returnObj()
|
||||
}
|
||||
|
||||
// UseFunc adds middleware to the group, see example code in doc.
|
||||
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes {
|
||||
group.Handlers = append(group.Handlers, middleware...)
|
||||
return group.returnObj()
|
||||
}
|
||||
|
||||
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
|
||||
// For example, all the routes that use a common middlware for authorization could be grouped.
|
||||
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
|
||||
return &RouterGroup{
|
||||
Handlers: group.combineHandlers(handlers),
|
||||
basePath: group.calculateAbsolutePath(relativePath),
|
||||
engine: group.engine,
|
||||
root: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMethodConfig is used to set config on specified method
|
||||
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup {
|
||||
group.baseConfig = config
|
||||
return group
|
||||
}
|
||||
|
||||
// BasePath router group base path.
|
||||
func (group *RouterGroup) BasePath() string {
|
||||
return group.basePath
|
||||
}
|
||||
|
||||
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
absolutePath := group.calculateAbsolutePath(relativePath)
|
||||
injections := group.injections(relativePath)
|
||||
handlers = group.combineHandlers(injections, handlers)
|
||||
group.engine.addRoute(httpMethod, absolutePath, handlers...)
|
||||
if group.baseConfig != nil {
|
||||
group.engine.SetMethodConfig(absolutePath, group.baseConfig)
|
||||
}
|
||||
return group.returnObj()
|
||||
}
|
||||
|
||||
// Handle registers a new request handle and middleware with the given path and method.
|
||||
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
|
||||
// See the example code in doc.
|
||||
//
|
||||
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
|
||||
// functions can be used.
|
||||
//
|
||||
// This function is intended for bulk loading and to allow the usage of less
|
||||
// frequently used, non-standardized or custom methods (e.g. for internal
|
||||
// communication with a proxy).
|
||||
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil {
|
||||
panic("http method " + httpMethod + " is not valid")
|
||||
}
|
||||
return group.handle(httpMethod, relativePath, handlers...)
|
||||
}
|
||||
|
||||
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
|
||||
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
return group.handle("HEAD", relativePath, handlers...)
|
||||
}
|
||||
|
||||
// GET is a shortcut for router.Handle("GET", path, handle).
|
||||
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
return group.handle("GET", relativePath, handlers...)
|
||||
}
|
||||
|
||||
// POST is a shortcut for router.Handle("POST", path, handle).
|
||||
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
return group.handle("POST", relativePath, handlers...)
|
||||
}
|
||||
|
||||
// PUT is a shortcut for router.Handle("PUT", path, handle).
|
||||
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
return group.handle("PUT", relativePath, handlers...)
|
||||
}
|
||||
|
||||
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
|
||||
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes {
|
||||
return group.handle("DELETE", relativePath, handlers...)
|
||||
}
|
||||
|
||||
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc {
|
||||
finalSize := len(group.Handlers)
|
||||
for _, handlers := range handlerGroups {
|
||||
finalSize += len(handlers)
|
||||
}
|
||||
if finalSize >= int(_abortIndex) {
|
||||
panic("too many handlers")
|
||||
}
|
||||
mergedHandlers := make([]HandlerFunc, finalSize)
|
||||
copy(mergedHandlers, group.Handlers)
|
||||
position := len(group.Handlers)
|
||||
for _, handlers := range handlerGroups {
|
||||
copy(mergedHandlers[position:], handlers)
|
||||
position += len(handlers)
|
||||
}
|
||||
return mergedHandlers
|
||||
}
|
||||
|
||||
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
|
||||
return joinPaths(group.basePath, relativePath)
|
||||
}
|
||||
|
||||
func (group *RouterGroup) returnObj() IRoutes {
|
||||
if group.root {
|
||||
return group.engine
|
||||
}
|
||||
return group
|
||||
}
|
||||
|
||||
// injections is
|
||||
func (group *RouterGroup) injections(relativePath string) []HandlerFunc {
|
||||
absPath := group.calculateAbsolutePath(relativePath)
|
||||
for _, injection := range group.engine.injections {
|
||||
if !injection.pattern.MatchString(absPath) {
|
||||
continue
|
||||
}
|
||||
return injection.handlers
|
||||
}
|
||||
return nil
|
||||
}
|
445
pkg/net/http/blademaster/server.go
Normal file
445
pkg/net/http/blademaster/server.go
Normal file
@ -0,0 +1,445 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/conf/dsn"
|
||||
"github.com/bilibili/Kratos/pkg/log"
|
||||
"github.com/bilibili/Kratos/pkg/net/ip"
|
||||
"github.com/bilibili/Kratos/pkg/net/metadata"
|
||||
"github.com/bilibili/Kratos/pkg/stat"
|
||||
xtime "github.com/bilibili/Kratos/pkg/time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxMemory = 32 << 20 // 32 MB
|
||||
)
|
||||
|
||||
var (
|
||||
_ IRouter = &Engine{}
|
||||
stats = stat.HTTPServer
|
||||
|
||||
_httpDSN string
|
||||
)
|
||||
|
||||
func init() {
|
||||
addFlag(flag.CommandLine)
|
||||
}
|
||||
|
||||
func addFlag(fs *flag.FlagSet) {
|
||||
v := os.Getenv("HTTP")
|
||||
if v == "" {
|
||||
v = "tcp://0.0.0.0:8000/?timeout=1s"
|
||||
}
|
||||
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.")
|
||||
}
|
||||
|
||||
func parseDSN(rawdsn string) *ServerConfig {
|
||||
conf := new(ServerConfig)
|
||||
d, err := dsn.Parse(rawdsn)
|
||||
if err != nil {
|
||||
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
|
||||
}
|
||||
if _, err = d.Bind(conf); err != nil {
|
||||
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
// Handler responds to an HTTP request.
|
||||
type Handler interface {
|
||||
ServeHTTP(c *Context)
|
||||
}
|
||||
|
||||
// HandlerFunc http request handler function.
|
||||
type HandlerFunc func(*Context)
|
||||
|
||||
// ServeHTTP calls f(ctx).
|
||||
func (f HandlerFunc) ServeHTTP(c *Context) {
|
||||
f(c)
|
||||
}
|
||||
|
||||
// ServerConfig is the bm server config model
|
||||
type ServerConfig struct {
|
||||
Network string `dsn:"network"`
|
||||
// FIXME: rename to Address
|
||||
Addr string `dsn:"address"`
|
||||
Timeout xtime.Duration `dsn:"query.timeout"`
|
||||
ReadTimeout xtime.Duration `dsn:"query.readTimeout"`
|
||||
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"`
|
||||
}
|
||||
|
||||
// MethodConfig is
|
||||
type MethodConfig struct {
|
||||
Timeout xtime.Duration
|
||||
}
|
||||
|
||||
// Start listen and serve bm engine by given DSN.
|
||||
func (engine *Engine) Start() error {
|
||||
conf := engine.conf
|
||||
l, err := net.Listen(conf.Network, conf.Addr)
|
||||
if err != nil {
|
||||
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("blademaster: start http listen addr: %s", conf.Addr)
|
||||
server := &http.Server{
|
||||
ReadTimeout: time.Duration(conf.ReadTimeout),
|
||||
WriteTimeout: time.Duration(conf.WriteTimeout),
|
||||
}
|
||||
go func() {
|
||||
if err := engine.RunServer(server, l); err != nil {
|
||||
if errors.Cause(err) == http.ErrServerClosed {
|
||||
log.Info("blademaster: server closed")
|
||||
return
|
||||
}
|
||||
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
|
||||
// Create an instance of Engine, by using New() or Default()
|
||||
type Engine struct {
|
||||
RouterGroup
|
||||
|
||||
lock sync.RWMutex
|
||||
conf *ServerConfig
|
||||
|
||||
address string
|
||||
|
||||
mux *http.ServeMux // http mux router
|
||||
server atomic.Value // store *http.Server
|
||||
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
|
||||
|
||||
pcLock sync.RWMutex
|
||||
methodConfigs map[string]*MethodConfig
|
||||
|
||||
injections []injection
|
||||
}
|
||||
|
||||
type injection struct {
|
||||
pattern *regexp.Regexp
|
||||
handlers []HandlerFunc
|
||||
}
|
||||
|
||||
// New returns a new blank Engine instance without any middleware attached.
|
||||
//
|
||||
// Deprecated: please use NewServer.
|
||||
func New() *Engine {
|
||||
engine := &Engine{
|
||||
RouterGroup: RouterGroup{
|
||||
Handlers: nil,
|
||||
basePath: "/",
|
||||
root: true,
|
||||
},
|
||||
address: ip.InternalIP(),
|
||||
conf: &ServerConfig{
|
||||
Timeout: xtime.Duration(time.Second),
|
||||
},
|
||||
mux: http.NewServeMux(),
|
||||
metastore: make(map[string]map[string]interface{}),
|
||||
methodConfigs: make(map[string]*MethodConfig),
|
||||
injections: make([]injection, 0),
|
||||
}
|
||||
engine.RouterGroup.engine = engine
|
||||
// NOTE add prometheus monitor location
|
||||
engine.addRoute("GET", "/metrics", monitor())
|
||||
engine.addRoute("GET", "/metadata", engine.metadata())
|
||||
startPerf()
|
||||
return engine
|
||||
}
|
||||
|
||||
// NewServer returns a new blank Engine instance without any middleware attached.
|
||||
func NewServer(conf *ServerConfig) *Engine {
|
||||
if conf == nil {
|
||||
if !flag.Parsed() {
|
||||
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n")
|
||||
}
|
||||
conf = parseDSN(_httpDSN)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "[blademaster] config will be deprecated, argument will be ignored. please use -http flag or HTTP env to configure http server.\n")
|
||||
}
|
||||
|
||||
engine := &Engine{
|
||||
RouterGroup: RouterGroup{
|
||||
Handlers: nil,
|
||||
basePath: "/",
|
||||
root: true,
|
||||
},
|
||||
address: ip.InternalIP(),
|
||||
mux: http.NewServeMux(),
|
||||
metastore: make(map[string]map[string]interface{}),
|
||||
methodConfigs: make(map[string]*MethodConfig),
|
||||
}
|
||||
if err := engine.SetConfig(conf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
engine.RouterGroup.engine = engine
|
||||
// NOTE add prometheus monitor location
|
||||
engine.addRoute("GET", "/metrics", monitor())
|
||||
engine.addRoute("GET", "/metadata", engine.metadata())
|
||||
startPerf()
|
||||
return engine
|
||||
}
|
||||
|
||||
// SetMethodConfig is used to set config on specified path
|
||||
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) {
|
||||
engine.pcLock.Lock()
|
||||
engine.methodConfigs[path] = mc
|
||||
engine.pcLock.Unlock()
|
||||
}
|
||||
|
||||
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||
func DefaultServer(conf *ServerConfig) *Engine {
|
||||
engine := NewServer(conf)
|
||||
engine.Use(Recovery(), Trace(), Logger())
|
||||
return engine
|
||||
}
|
||||
|
||||
// Default returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
|
||||
//
|
||||
// Deprecated: please use DefaultServer.
|
||||
func Default() *Engine {
|
||||
engine := New()
|
||||
engine.Use(Recovery(), Trace(), Logger())
|
||||
return engine
|
||||
}
|
||||
|
||||
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) {
|
||||
if path[0] != '/' {
|
||||
panic("blademaster: path must begin with '/'")
|
||||
}
|
||||
if method == "" {
|
||||
panic("blademaster: HTTP method can not be empty")
|
||||
}
|
||||
if len(handlers) == 0 {
|
||||
panic("blademaster: there must be at least one handler")
|
||||
}
|
||||
if _, ok := engine.metastore[path]; !ok {
|
||||
engine.metastore[path] = make(map[string]interface{})
|
||||
}
|
||||
engine.metastore[path]["method"] = method
|
||||
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
|
||||
c := &Context{
|
||||
Context: nil,
|
||||
engine: engine,
|
||||
index: -1,
|
||||
handlers: nil,
|
||||
Keys: nil,
|
||||
method: "",
|
||||
Error: nil,
|
||||
}
|
||||
|
||||
c.Request = req
|
||||
c.Writer = w
|
||||
c.handlers = handlers
|
||||
c.method = method
|
||||
|
||||
engine.handleContext(c)
|
||||
})
|
||||
}
|
||||
|
||||
// SetConfig is used to set the engine configuration.
|
||||
// Only the valid config will be loaded.
|
||||
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) {
|
||||
if conf.Timeout <= 0 {
|
||||
return errors.New("blademaster: config timeout must greater than 0")
|
||||
}
|
||||
if conf.Network == "" {
|
||||
conf.Network = "tcp"
|
||||
}
|
||||
engine.lock.Lock()
|
||||
engine.conf = conf
|
||||
engine.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (engine *Engine) methodConfig(path string) *MethodConfig {
|
||||
engine.pcLock.RLock()
|
||||
mc := engine.methodConfigs[path]
|
||||
engine.pcLock.RUnlock()
|
||||
return mc
|
||||
}
|
||||
|
||||
func (engine *Engine) handleContext(c *Context) {
|
||||
var cancel func()
|
||||
req := c.Request
|
||||
ctype := req.Header.Get("Content-Type")
|
||||
switch {
|
||||
case strings.Contains(ctype, "multipart/form-data"):
|
||||
req.ParseMultipartForm(defaultMaxMemory)
|
||||
default:
|
||||
req.ParseForm()
|
||||
}
|
||||
// get derived timeout from http request header,
|
||||
// compare with the engine configured,
|
||||
// and use the minimum one
|
||||
engine.lock.RLock()
|
||||
tm := time.Duration(engine.conf.Timeout)
|
||||
engine.lock.RUnlock()
|
||||
// the method config is preferred
|
||||
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil {
|
||||
tm = time.Duration(pc.Timeout)
|
||||
}
|
||||
if ctm := timeout(req); ctm > 0 && tm > ctm {
|
||||
tm = ctm
|
||||
}
|
||||
md := metadata.MD{
|
||||
metadata.Color: color(req),
|
||||
metadata.RemoteIP: remoteIP(req),
|
||||
metadata.RemotePort: remotePort(req),
|
||||
metadata.Caller: caller(req),
|
||||
metadata.Mirror: mirror(req),
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), md)
|
||||
if tm > 0 {
|
||||
c.Context, cancel = context.WithTimeout(ctx, tm)
|
||||
} else {
|
||||
c.Context, cancel = context.WithCancel(ctx)
|
||||
}
|
||||
defer cancel()
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// Router return a http.Handler for using http.ListenAndServe() directly.
|
||||
func (engine *Engine) Router() http.Handler {
|
||||
return engine.mux
|
||||
}
|
||||
|
||||
// Server is used to load stored http server.
|
||||
func (engine *Engine) Server() *http.Server {
|
||||
s, ok := engine.server.Load().(*http.Server)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Shutdown the http server without interrupting active connections.
|
||||
func (engine *Engine) Shutdown(ctx context.Context) error {
|
||||
server := engine.Server()
|
||||
if server == nil {
|
||||
return errors.New("blademaster: no server")
|
||||
}
|
||||
return errors.WithStack(server.Shutdown(ctx))
|
||||
}
|
||||
|
||||
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
|
||||
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||
// For example, this is the right place for a logger or error management middleware.
|
||||
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes {
|
||||
engine.RouterGroup.UseFunc(middleware...)
|
||||
return engine
|
||||
}
|
||||
|
||||
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
|
||||
// included in the handlers chain for every single request. Even 404, 405, static files...
|
||||
// For example, this is the right place for a logger or error management middleware.
|
||||
func (engine *Engine) Use(middleware ...Handler) IRoutes {
|
||||
engine.RouterGroup.Use(middleware...)
|
||||
return engine
|
||||
}
|
||||
|
||||
// Ping is used to set the general HTTP ping handler.
|
||||
func (engine *Engine) Ping(handler HandlerFunc) {
|
||||
engine.GET("/monitor/ping", handler)
|
||||
}
|
||||
|
||||
// Register is used to export metadata to discovery.
|
||||
func (engine *Engine) Register(handler HandlerFunc) {
|
||||
engine.GET("/register", handler)
|
||||
}
|
||||
|
||||
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
|
||||
// It is a shortcut for http.ListenAndServe(addr, router)
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) Run(addr ...string) (err error) {
|
||||
address := resolveAddress(addr)
|
||||
server := &http.Server{
|
||||
Addr: address,
|
||||
Handler: engine.mux,
|
||||
}
|
||||
engine.server.Store(server)
|
||||
if err = server.ListenAndServe(); err != nil {
|
||||
err = errors.Wrapf(err, "addrs: %v", addr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
|
||||
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: engine.mux,
|
||||
}
|
||||
engine.server.Store(server)
|
||||
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
|
||||
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
|
||||
// through the specified unix socket (ie. a file).
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunUnix(file string) (err error) {
|
||||
os.Remove(file)
|
||||
listener, err := net.Listen("unix", file)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "unix: %s", file)
|
||||
return
|
||||
}
|
||||
defer listener.Close()
|
||||
server := &http.Server{
|
||||
Handler: engine.mux,
|
||||
}
|
||||
engine.server.Store(server)
|
||||
if err = server.Serve(listener); err != nil {
|
||||
err = errors.Wrapf(err, "unix: %s", file)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// RunServer will serve and start listening HTTP requests by given server and listener.
|
||||
// Note: this method will block the calling goroutine indefinitely unless an error happens.
|
||||
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) {
|
||||
server.Handler = engine.mux
|
||||
engine.server.Store(server)
|
||||
if err = server.Serve(l); err != nil {
|
||||
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (engine *Engine) metadata() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
c.JSON(engine.metastore, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Inject is
|
||||
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {
|
||||
engine.injections = append(engine.injections, injection{
|
||||
pattern: regexp.MustCompile(pattern),
|
||||
handlers: handlers,
|
||||
})
|
||||
}
|
@ -4,12 +4,43 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"strconv"
|
||||
|
||||
"github.com/bilibili/Kratos/pkg/net/metadata"
|
||||
"github.com/bilibili/Kratos/pkg/net/trace"
|
||||
)
|
||||
|
||||
const _defaultComponentName = "net/http"
|
||||
|
||||
// Trace is trace middleware
|
||||
func Trace() HandlerFunc {
|
||||
return func(c *Context) {
|
||||
// handle http request
|
||||
// get derived trace from http request header
|
||||
t, err := trace.Extract(trace.HTTPFormat, c.Request.Header)
|
||||
if err != nil {
|
||||
var opts []trace.Option
|
||||
if ok, _ := strconv.ParseBool(trace.KratosTraceDebug); ok {
|
||||
opts = append(opts, trace.EnableDebug())
|
||||
}
|
||||
t = trace.New(c.Request.URL.Path, opts...)
|
||||
}
|
||||
t.SetTitle(c.Request.URL.Path)
|
||||
t.SetTag(trace.String(trace.TagComponent, _defaultComponentName))
|
||||
t.SetTag(trace.String(trace.TagHTTPMethod, c.Request.Method))
|
||||
t.SetTag(trace.String(trace.TagHTTPURL, c.Request.URL.String()))
|
||||
t.SetTag(trace.String(trace.TagSpanKind, "server"))
|
||||
// business tag
|
||||
t.SetTag(trace.String("caller", metadata.String(c.Context, metadata.Caller)))
|
||||
// export trace id to user.
|
||||
// TODO(zhoujiahui): trace package should be updated
|
||||
// c.Writer.Header().Set(trace.KratosTraceID, t.TraceID())
|
||||
c.Context = trace.NewContext(c.Context, t)
|
||||
c.Next()
|
||||
t.Finish(&c.Error)
|
||||
}
|
||||
}
|
||||
|
||||
type closeTracker struct {
|
||||
io.ReadCloser
|
||||
tr trace.Trace
|
||||
|
42
pkg/net/http/blademaster/utils.go
Normal file
42
pkg/net/http/blademaster/utils.go
Normal file
@ -0,0 +1,42 @@
|
||||
package blademaster
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
func lastChar(str string) uint8 {
|
||||
if str == "" {
|
||||
panic("The length of the string can't be 0")
|
||||
}
|
||||
return str[len(str)-1]
|
||||
}
|
||||
|
||||
func joinPaths(absolutePath, relativePath string) string {
|
||||
if relativePath == "" {
|
||||
return absolutePath
|
||||
}
|
||||
|
||||
finalPath := path.Join(absolutePath, relativePath)
|
||||
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/'
|
||||
if appendSlash {
|
||||
return finalPath + "/"
|
||||
}
|
||||
return finalPath
|
||||
}
|
||||
|
||||
func resolveAddress(addr []string) string {
|
||||
switch len(addr) {
|
||||
case 0:
|
||||
if port := os.Getenv("PORT"); port != "" {
|
||||
//debugPrint("Environment variable PORT=\"%s\"", port)
|
||||
return ":" + port
|
||||
}
|
||||
//debugPrint("Environment variable PORT is undefined. Using port :8080 by default")
|
||||
return ":8080"
|
||||
case 1:
|
||||
return addr[0]
|
||||
default:
|
||||
panic("too much parameters")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user