1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-24 03:46:37 +02:00

blademaster initial (#6)

This commit is contained in:
realityone 2019-04-11 15:07:22 +08:00 committed by Felix Hao
parent 1efe0a084e
commit 96d32e866a
34 changed files with 3107 additions and 3 deletions

View File

@ -0,0 +1,85 @@
package binding
import (
"net/http"
"strings"
"gopkg.in/go-playground/validator.v9"
)
// MIME
const (
MIMEJSON = "application/json"
MIMEHTML = "text/html"
MIMEXML = "application/xml"
MIMEXML2 = "text/xml"
MIMEPlain = "text/plain"
MIMEPOSTForm = "application/x-www-form-urlencoded"
MIMEMultipartPOSTForm = "multipart/form-data"
)
// Binding http binding request interface.
type Binding interface {
Name() string
Bind(*http.Request, interface{}) error
}
// StructValidator http validator interface.
type StructValidator interface {
// ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right.
// If the received type is not a struct, any validation should be skipped and nil must be returned.
// If the received type is a struct or pointer to a struct, the validation should be performed.
// If the struct is not valid or the validation itself fails, a descriptive error should be returned.
// Otherwise nil must be returned.
ValidateStruct(interface{}) error
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
// NOTE: if the key already exists, the previous validation function will be replaced.
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
RegisterValidation(string, validator.Func) error
}
// Validator default validator.
var Validator StructValidator = &defaultValidator{}
// Binding
var (
JSON = jsonBinding{}
XML = xmlBinding{}
Form = formBinding{}
Query = queryBinding{}
FormPost = formPostBinding{}
FormMultipart = formMultipartBinding{}
)
// Default get by binding type by method and contexttype.
func Default(method, contentType string) Binding {
if method == "GET" {
return Form
}
contentType = stripContentTypeParam(contentType)
switch contentType {
case MIMEJSON:
return JSON
case MIMEXML, MIMEXML2:
return XML
default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
return Form
}
}
func validate(obj interface{}) error {
if Validator == nil {
return nil
}
return Validator.ValidateStruct(obj)
}
func stripContentTypeParam(contentType string) string {
i := strings.Index(contentType, ";")
if i != -1 {
contentType = contentType[:i]
}
return contentType
}

View File

@ -0,0 +1,342 @@
package binding
import (
"bytes"
"mime/multipart"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
type FooStruct struct {
Foo string `msgpack:"foo" json:"foo" form:"foo" xml:"foo" validate:"required"`
}
type FooBarStruct struct {
FooStruct
Bar string `msgpack:"bar" json:"bar" form:"bar" xml:"bar" validate:"required"`
Slice []string `form:"slice" validate:"max=10"`
}
type ComplexDefaultStruct struct {
Int int `form:"int" default:"999"`
String string `form:"string" default:"default-string"`
Bool bool `form:"bool" default:"false"`
Int64Slice []int64 `form:"int64_slice,split" default:"1,2,3,4"`
Int8Slice []int8 `form:"int8_slice,split" default:"1,2,3,4"`
}
type Int8SliceStruct struct {
State []int8 `form:"state,split"`
}
type Int64SliceStruct struct {
State []int64 `form:"state,split"`
}
type StringSliceStruct struct {
State []string `form:"state,split"`
}
func TestBindingDefault(t *testing.T) {
assert.Equal(t, Default("GET", ""), Form)
assert.Equal(t, Default("GET", MIMEJSON), Form)
assert.Equal(t, Default("GET", MIMEJSON+"; charset=utf-8"), Form)
assert.Equal(t, Default("POST", MIMEJSON), JSON)
assert.Equal(t, Default("PUT", MIMEJSON), JSON)
assert.Equal(t, Default("POST", MIMEJSON+"; charset=utf-8"), JSON)
assert.Equal(t, Default("PUT", MIMEJSON+"; charset=utf-8"), JSON)
assert.Equal(t, Default("POST", MIMEXML), XML)
assert.Equal(t, Default("PUT", MIMEXML2), XML)
assert.Equal(t, Default("POST", MIMEPOSTForm), Form)
assert.Equal(t, Default("PUT", MIMEPOSTForm), Form)
assert.Equal(t, Default("POST", MIMEPOSTForm+"; charset=utf-8"), Form)
assert.Equal(t, Default("PUT", MIMEPOSTForm+"; charset=utf-8"), Form)
assert.Equal(t, Default("POST", MIMEMultipartPOSTForm), Form)
assert.Equal(t, Default("PUT", MIMEMultipartPOSTForm), Form)
}
func TestStripContentType(t *testing.T) {
c1 := "application/vnd.mozilla.xul+xml"
c2 := "application/vnd.mozilla.xul+xml; charset=utf-8"
assert.Equal(t, stripContentTypeParam(c1), c1)
assert.Equal(t, stripContentTypeParam(c2), "application/vnd.mozilla.xul+xml")
}
func TestBindInt8Form(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(Int8SliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []int8{1, 2, 3}, q.State)
params = "state=1,2,3,256"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.Error(t, Form.Bind(req, q))
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
params = "state=1,,2"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int8SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.EqualValues(t, []int8{1, 2}, q.State)
}
func TestBindInt64Form(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(Int64SliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []int64{1, 2, 3}, q.State)
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(Int64SliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
}
func TestBindStringForm(t *testing.T) {
params := "state=1,2,3"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(StringSliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []string{"1", "2", "3"}, q.State)
params = "state="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(StringSliceStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Len(t, q.State, 0)
params = "state=p,,p"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(StringSliceStruct)
Form.Bind(req, q)
assert.EqualValues(t, []string{"p", "p"}, q.State)
}
func TestBindingJSON(t *testing.T) {
testBodyBinding(t,
JSON, "json",
"/", "/",
`{"foo": "bar"}`, `{"bar": "foo"}`)
}
func TestBindingForm(t *testing.T) {
testFormBinding(t, "POST",
"/", "/",
"foo=bar&bar=foo&slice=a&slice=b", "bar2=foo")
}
func TestBindingForm2(t *testing.T) {
testFormBinding(t, "GET",
"/?foo=bar&bar=foo", "/?bar2=foo",
"", "")
}
func TestBindingQuery(t *testing.T) {
testQueryBinding(t, "POST",
"/?foo=bar&bar=foo", "/",
"foo=unused", "bar2=foo")
}
func TestBindingQuery2(t *testing.T) {
testQueryBinding(t, "GET",
"/?foo=bar&bar=foo", "/?bar2=foo",
"foo=unused", "")
}
func TestBindingXML(t *testing.T) {
testBodyBinding(t,
XML, "xml",
"/", "/",
"<map><foo>bar</foo></map>", "<map><bar>foo</bar></map>")
}
func createFormPostRequest() *http.Request {
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", bytes.NewBufferString("foo=bar&bar=foo"))
req.Header.Set("Content-Type", MIMEPOSTForm)
return req
}
func createFormMultipartRequest() *http.Request {
boundary := "--testboundary"
body := new(bytes.Buffer)
mw := multipart.NewWriter(body)
defer mw.Close()
mw.SetBoundary(boundary)
mw.WriteField("foo", "bar")
mw.WriteField("bar", "foo")
req, _ := http.NewRequest("POST", "/?foo=getfoo&bar=getbar", body)
req.Header.Set("Content-Type", MIMEMultipartPOSTForm+"; boundary="+boundary)
return req
}
func TestBindingFormPost(t *testing.T) {
req := createFormPostRequest()
var obj FooBarStruct
FormPost.Bind(req, &obj)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func TestBindingFormMultipart(t *testing.T) {
req := createFormMultipartRequest()
var obj FooBarStruct
FormMultipart.Bind(req, &obj)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func TestValidationFails(t *testing.T) {
var obj FooStruct
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
err := JSON.Bind(req, &obj)
assert.Error(t, err)
}
func TestValidationDisabled(t *testing.T) {
backup := Validator
Validator = nil
defer func() { Validator = backup }()
var obj FooStruct
req := requestWithBody("POST", "/", `{"bar": "foo"}`)
err := JSON.Bind(req, &obj)
assert.NoError(t, err)
}
func TestExistsSucceeds(t *testing.T) {
type HogeStruct struct {
Hoge *int `json:"hoge" binding:"exists"`
}
var obj HogeStruct
req := requestWithBody("POST", "/", `{"hoge": 0}`)
err := JSON.Bind(req, &obj)
assert.NoError(t, err)
}
func TestFormDefaultValue(t *testing.T) {
params := "int=333&string=hello&bool=true&int64_slice=5,6,7,8&int8_slice=5,6,7,8"
req, _ := http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q := new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 333, q.Int)
assert.Equal(t, "hello", q.String)
assert.Equal(t, true, q.Bool)
assert.EqualValues(t, []int64{5, 6, 7, 8}, q.Int64Slice)
assert.EqualValues(t, []int8{5, 6, 7, 8}, q.Int8Slice)
params = "string=hello&bool=false"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "hello", q.String)
assert.Equal(t, false, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
params = "strings=hello"
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "default-string", q.String)
assert.Equal(t, false, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
params = "int=&string=&bool=true&int64_slice=&int8_slice="
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
q = new(ComplexDefaultStruct)
assert.NoError(t, Form.Bind(req, q))
assert.Equal(t, 999, q.Int)
assert.Equal(t, "default-string", q.String)
assert.Equal(t, true, q.Bool)
assert.EqualValues(t, []int64{1, 2, 3, 4}, q.Int64Slice)
assert.EqualValues(t, []int8{1, 2, 3, 4}, q.Int8Slice)
}
func testFormBinding(t *testing.T, method, path, badPath, body, badBody string) {
b := Form
assert.Equal(t, b.Name(), "form")
obj := FooBarStruct{}
req := requestWithBody(method, path, body)
if method == "POST" {
req.Header.Add("Content-Type", MIMEPOSTForm)
}
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
obj = FooBarStruct{}
req = requestWithBody(method, badPath, badBody)
err = JSON.Bind(req, &obj)
assert.Error(t, err)
}
func testQueryBinding(t *testing.T, method, path, badPath, body, badBody string) {
b := Query
assert.Equal(t, b.Name(), "query")
obj := FooBarStruct{}
req := requestWithBody(method, path, body)
if method == "POST" {
req.Header.Add("Content-Type", MIMEPOSTForm)
}
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
assert.Equal(t, obj.Bar, "foo")
}
func testBodyBinding(t *testing.T, b Binding, name, path, badPath, body, badBody string) {
assert.Equal(t, b.Name(), name)
obj := FooStruct{}
req := requestWithBody("POST", path, body)
err := b.Bind(req, &obj)
assert.NoError(t, err)
assert.Equal(t, obj.Foo, "bar")
obj = FooStruct{}
req = requestWithBody("POST", badPath, badBody)
err = JSON.Bind(req, &obj)
assert.Error(t, err)
}
func requestWithBody(method, path, body string) (req *http.Request) {
req, _ = http.NewRequest(method, path, bytes.NewBufferString(body))
return
}
func BenchmarkBindingForm(b *testing.B) {
req := requestWithBody("POST", "/", "foo=bar&bar=foo&slice=a&slice=b&slice=c&slice=w")
req.Header.Add("Content-Type", MIMEPOSTForm)
f := Form
for i := 0; i < b.N; i++ {
obj := FooBarStruct{}
f.Bind(req, &obj)
}
}

View File

@ -0,0 +1,45 @@
package binding
import (
"reflect"
"sync"
"gopkg.in/go-playground/validator.v9"
)
type defaultValidator struct {
once sync.Once
validate *validator.Validate
}
var _ StructValidator = &defaultValidator{}
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
if kindOfData(obj) == reflect.Struct {
v.lazyinit()
if err := v.validate.Struct(obj); err != nil {
return err
}
}
return nil
}
func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error {
v.lazyinit()
return v.validate.RegisterValidation(key, fn)
}
func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.New()
})
}
func kindOfData(data interface{}) reflect.Kind {
value := reflect.ValueOf(data)
valueType := value.Kind()
if valueType == reflect.Ptr {
valueType = value.Elem().Kind()
}
return valueType
}

View File

@ -0,0 +1,113 @@
// Code generated by protoc-gen-go.
// source: test.proto
// DO NOT EDIT!
/*
Package example is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
Test
*/
package example
import proto "github.com/golang/protobuf/proto"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = math.Inf
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
}
var FOO_value = map[string]int32{
"X": 17,
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data, "FOO")
if err != nil {
return err
}
*x = FOO(value)
return nil
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
const Default_Test_Type int32 = 77
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
}
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
return Default_Test_Type
}
func (m *Test) GetReps() []int64 {
if m != nil {
return m.Reps
}
return nil
}
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (*Test_OptionalGroup) ProtoMessage() {}
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}

View File

@ -0,0 +1,12 @@
package example;
enum FOO {X=17;};
message Test {
required string label = 1;
optional int32 type = 2[default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4{
required string RequiredField = 5;
}
}

View File

@ -0,0 +1,36 @@
package binding
import (
"fmt"
"log"
"net/http"
)
type Arg struct {
Max int64 `form:"max" validate:"max=10"`
Min int64 `form:"min" validate:"min=2"`
Range int64 `form:"range" validate:"min=1,max=10"`
// use split option to split arg 1,2,3 into slice [1 2 3]
// otherwise slice type with parse url.Values (eg:a=b&a=c) default.
Slice []int64 `form:"slice,split" validate:"min=1"`
}
func ExampleBinding() {
req := initHTTP("max=9&min=3&range=3&slice=1,2,3")
arg := new(Arg)
if err := Form.Bind(req, arg); err != nil {
log.Fatal(err)
}
fmt.Printf("arg.Max %d\narg.Min %d\narg.Range %d\narg.Slice %v", arg.Max, arg.Min, arg.Range, arg.Slice)
// Output:
// arg.Max 9
// arg.Min 3
// arg.Range 3
// arg.Slice [1 2 3]
}
func initHTTP(params string) (req *http.Request) {
req, _ = http.NewRequest("GET", "http://api.bilibili.com/test?"+params, nil)
req.ParseForm()
return
}

View File

@ -0,0 +1,55 @@
package binding
import (
"net/http"
"github.com/pkg/errors"
)
const defaultMemory = 32 * 1024 * 1024
type formBinding struct{}
type formPostBinding struct{}
type formMultipartBinding struct{}
func (f formBinding) Name() string {
return "form"
}
func (f formBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.Form); err != nil {
return err
}
return validate(obj)
}
func (f formPostBinding) Name() string {
return "form-urlencoded"
}
func (f formPostBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseForm(); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.PostForm); err != nil {
return err
}
return validate(obj)
}
func (f formMultipartBinding) Name() string {
return "multipart/form-data"
}
func (f formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
if err := req.ParseMultipartForm(defaultMemory); err != nil {
return errors.WithStack(err)
}
if err := mapForm(obj, req.MultipartForm.Value); err != nil {
return err
}
return validate(obj)
}

View File

@ -0,0 +1,276 @@
package binding
import (
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
)
// scache struct reflect type cache.
var scache = &cache{
data: make(map[reflect.Type]*sinfo),
}
type cache struct {
data map[reflect.Type]*sinfo
mutex sync.RWMutex
}
func (c *cache) get(obj reflect.Type) (s *sinfo) {
var ok bool
c.mutex.RLock()
if s, ok = c.data[obj]; !ok {
c.mutex.RUnlock()
s = c.set(obj)
return
}
c.mutex.RUnlock()
return
}
func (c *cache) set(obj reflect.Type) (s *sinfo) {
s = new(sinfo)
tp := obj.Elem()
for i := 0; i < tp.NumField(); i++ {
fd := new(field)
fd.tp = tp.Field(i)
tag := fd.tp.Tag.Get("form")
fd.name, fd.option = parseTag(tag)
if defV := fd.tp.Tag.Get("default"); defV != "" {
dv := reflect.New(fd.tp.Type).Elem()
setWithProperType(fd.tp.Type.Kind(), []string{defV}, dv, fd.option)
fd.hasDefault = true
fd.defaultValue = dv
}
s.field = append(s.field, fd)
}
c.mutex.Lock()
c.data[obj] = s
c.mutex.Unlock()
return
}
type sinfo struct {
field []*field
}
type field struct {
tp reflect.StructField
name string
option tagOptions
hasDefault bool // if field had default value
defaultValue reflect.Value // field default value
}
func mapForm(ptr interface{}, form map[string][]string) error {
sinfo := scache.get(reflect.TypeOf(ptr))
val := reflect.ValueOf(ptr).Elem()
for i, fd := range sinfo.field {
typeField := fd.tp
structField := val.Field(i)
if !structField.CanSet() {
continue
}
structFieldKind := structField.Kind()
inputFieldName := fd.name
if inputFieldName == "" {
inputFieldName = typeField.Name
// if "form" tag is nil, we inspect if the field is a struct.
// this would not make sense for JSON parsing but it does for a form
// since data is flatten
if structFieldKind == reflect.Struct {
err := mapForm(structField.Addr().Interface(), form)
if err != nil {
return err
}
continue
}
}
inputValue, exists := form[inputFieldName]
if !exists {
// Set the field as default value when the input value is not exist
if fd.hasDefault {
structField.Set(fd.defaultValue)
}
continue
}
// Set the field as default value when the input value is empty
if fd.hasDefault && inputValue[0] == "" {
structField.Set(fd.defaultValue)
continue
}
if _, isTime := structField.Interface().(time.Time); isTime {
if err := setTimeField(inputValue[0], typeField, structField); err != nil {
return err
}
continue
}
if err := setWithProperType(typeField.Type.Kind(), inputValue, structField, fd.option); err != nil {
return err
}
}
return nil
}
func setWithProperType(valueKind reflect.Kind, val []string, structField reflect.Value, option tagOptions) error {
switch valueKind {
case reflect.Int:
return setIntField(val[0], 0, structField)
case reflect.Int8:
return setIntField(val[0], 8, structField)
case reflect.Int16:
return setIntField(val[0], 16, structField)
case reflect.Int32:
return setIntField(val[0], 32, structField)
case reflect.Int64:
return setIntField(val[0], 64, structField)
case reflect.Uint:
return setUintField(val[0], 0, structField)
case reflect.Uint8:
return setUintField(val[0], 8, structField)
case reflect.Uint16:
return setUintField(val[0], 16, structField)
case reflect.Uint32:
return setUintField(val[0], 32, structField)
case reflect.Uint64:
return setUintField(val[0], 64, structField)
case reflect.Bool:
return setBoolField(val[0], structField)
case reflect.Float32:
return setFloatField(val[0], 32, structField)
case reflect.Float64:
return setFloatField(val[0], 64, structField)
case reflect.String:
structField.SetString(val[0])
case reflect.Slice:
if option.Contains("split") {
val = strings.Split(val[0], ",")
}
filtered := filterEmpty(val)
switch structField.Type().Elem().Kind() {
case reflect.Int64:
valSli := make([]int64, 0, len(filtered))
for i := 0; i < len(filtered); i++ {
d, err := strconv.ParseInt(filtered[i], 10, 64)
if err != nil {
return err
}
valSli = append(valSli, d)
}
structField.Set(reflect.ValueOf(valSli))
case reflect.String:
valSli := make([]string, 0, len(filtered))
for i := 0; i < len(filtered); i++ {
valSli = append(valSli, filtered[i])
}
structField.Set(reflect.ValueOf(valSli))
default:
sliceOf := structField.Type().Elem().Kind()
numElems := len(filtered)
slice := reflect.MakeSlice(structField.Type(), len(filtered), len(filtered))
for i := 0; i < numElems; i++ {
if err := setWithProperType(sliceOf, filtered[i:], slice.Index(i), ""); err != nil {
return err
}
}
structField.Set(slice)
}
default:
return errors.New("Unknown type")
}
return nil
}
func setIntField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
intVal, err := strconv.ParseInt(val, 10, bitSize)
if err == nil {
field.SetInt(intVal)
}
return errors.WithStack(err)
}
func setUintField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0"
}
uintVal, err := strconv.ParseUint(val, 10, bitSize)
if err == nil {
field.SetUint(uintVal)
}
return errors.WithStack(err)
}
func setBoolField(val string, field reflect.Value) error {
if val == "" {
val = "false"
}
boolVal, err := strconv.ParseBool(val)
if err == nil {
field.SetBool(boolVal)
}
return nil
}
func setFloatField(val string, bitSize int, field reflect.Value) error {
if val == "" {
val = "0.0"
}
floatVal, err := strconv.ParseFloat(val, bitSize)
if err == nil {
field.SetFloat(floatVal)
}
return errors.WithStack(err)
}
func setTimeField(val string, structField reflect.StructField, value reflect.Value) error {
timeFormat := structField.Tag.Get("time_format")
if timeFormat == "" {
return errors.New("Blank time format")
}
if val == "" {
value.Set(reflect.ValueOf(time.Time{}))
return nil
}
l := time.Local
if isUTC, _ := strconv.ParseBool(structField.Tag.Get("time_utc")); isUTC {
l = time.UTC
}
if locTag := structField.Tag.Get("time_location"); locTag != "" {
loc, err := time.LoadLocation(locTag)
if err != nil {
return errors.WithStack(err)
}
l = loc
}
t, err := time.ParseInLocation(timeFormat, val, l)
if err != nil {
return errors.WithStack(err)
}
value.Set(reflect.ValueOf(t))
return nil
}
func filterEmpty(val []string) []string {
filtered := make([]string, 0, len(val))
for _, v := range val {
if v != "" {
filtered = append(filtered, v)
}
}
return filtered
}

View File

@ -0,0 +1,22 @@
package binding
import (
"encoding/json"
"net/http"
"github.com/pkg/errors"
)
type jsonBinding struct{}
func (jsonBinding) Name() string {
return "json"
}
func (jsonBinding) Bind(req *http.Request, obj interface{}) error {
decoder := json.NewDecoder(req.Body)
if err := decoder.Decode(obj); err != nil {
return errors.WithStack(err)
}
return validate(obj)
}

View File

@ -0,0 +1,19 @@
package binding
import (
"net/http"
)
type queryBinding struct{}
func (queryBinding) Name() string {
return "query"
}
func (queryBinding) Bind(req *http.Request, obj interface{}) error {
values := req.URL.Query()
if err := mapForm(obj, values); err != nil {
return err
}
return validate(obj)
}

View File

@ -0,0 +1,44 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binding
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
if idx := strings.Index(tag, ","); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var next string
i := strings.Index(s, ",")
if i >= 0 {
s, next = s[:i], s[i+1:]
}
if s == optionName {
return true
}
s = next
}
return false
}

View File

@ -0,0 +1,209 @@
package binding
import (
"bytes"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
type testInterface interface {
String() string
}
type substructNoValidation struct {
IString string
IInt int
}
type mapNoValidationSub map[string]substructNoValidation
type structNoValidationValues struct {
substructNoValidation
Boolean bool
Uinteger uint
Integer int
Integer8 int8
Integer16 int16
Integer32 int32
Integer64 int64
Uinteger8 uint8
Uinteger16 uint16
Uinteger32 uint32
Uinteger64 uint64
Float32 float32
Float64 float64
String string
Date time.Time
Struct substructNoValidation
InlinedStruct struct {
String []string
Integer int
}
IntSlice []int
IntPointerSlice []*int
StructPointerSlice []*substructNoValidation
StructSlice []substructNoValidation
InterfaceSlice []testInterface
UniversalInterface interface{}
CustomInterface testInterface
FloatMap map[string]float32
StructMap mapNoValidationSub
}
func createNoValidationValues() structNoValidationValues {
integer := 1
s := structNoValidationValues{
Boolean: true,
Uinteger: 1 << 29,
Integer: -10000,
Integer8: 120,
Integer16: -20000,
Integer32: 1 << 29,
Integer64: 1 << 61,
Uinteger8: 250,
Uinteger16: 50000,
Uinteger32: 1 << 31,
Uinteger64: 1 << 62,
Float32: 123.456,
Float64: 123.456789,
String: "text",
Date: time.Time{},
CustomInterface: &bytes.Buffer{},
Struct: substructNoValidation{},
IntSlice: []int{-3, -2, 1, 0, 1, 2, 3},
IntPointerSlice: []*int{&integer},
StructSlice: []substructNoValidation{},
UniversalInterface: 1.2,
FloatMap: map[string]float32{
"foo": 1.23,
"bar": 232.323,
},
StructMap: mapNoValidationSub{
"foo": substructNoValidation{},
"bar": substructNoValidation{},
},
// StructPointerSlice []noValidationSub
// InterfaceSlice []testInterface
}
s.InlinedStruct.Integer = 1000
s.InlinedStruct.String = []string{"first", "second"}
s.IString = "substring"
s.IInt = 987654
return s
}
func TestValidateNoValidationValues(t *testing.T) {
origin := createNoValidationValues()
test := createNoValidationValues()
empty := structNoValidationValues{}
assert.Nil(t, validate(test))
assert.Nil(t, validate(&test))
assert.Nil(t, validate(empty))
assert.Nil(t, validate(&empty))
assert.Equal(t, origin, test)
}
type structNoValidationPointer struct {
// substructNoValidation
Boolean bool
Uinteger *uint
Integer *int
Integer8 *int8
Integer16 *int16
Integer32 *int32
Integer64 *int64
Uinteger8 *uint8
Uinteger16 *uint16
Uinteger32 *uint32
Uinteger64 *uint64
Float32 *float32
Float64 *float64
String *string
Date *time.Time
Struct *substructNoValidation
IntSlice *[]int
IntPointerSlice *[]*int
StructPointerSlice *[]*substructNoValidation
StructSlice *[]substructNoValidation
InterfaceSlice *[]testInterface
FloatMap *map[string]float32
StructMap *mapNoValidationSub
}
func TestValidateNoValidationPointers(t *testing.T) {
//origin := createNoValidation_values()
//test := createNoValidation_values()
empty := structNoValidationPointer{}
//assert.Nil(t, validate(test))
//assert.Nil(t, validate(&test))
assert.Nil(t, validate(empty))
assert.Nil(t, validate(&empty))
//assert.Equal(t, origin, test)
}
type Object map[string]interface{}
func TestValidatePrimitives(t *testing.T) {
obj := Object{"foo": "bar", "bar": 1}
assert.NoError(t, validate(obj))
assert.NoError(t, validate(&obj))
assert.Equal(t, obj, Object{"foo": "bar", "bar": 1})
obj2 := []Object{{"foo": "bar", "bar": 1}, {"foo": "bar", "bar": 1}}
assert.NoError(t, validate(obj2))
assert.NoError(t, validate(&obj2))
nu := 10
assert.NoError(t, validate(nu))
assert.NoError(t, validate(&nu))
assert.Equal(t, nu, 10)
str := "value"
assert.NoError(t, validate(str))
assert.NoError(t, validate(&str))
assert.Equal(t, str, "value")
}
// structCustomValidation is a helper struct we use to check that
// custom validation can be registered on it.
// The `notone` binding directive is for custom validation and registered later.
// type structCustomValidation struct {
// Integer int `binding:"notone"`
// }
// notOne is a custom validator meant to be used with `validator.v8` library.
// The method signature for `v9` is significantly different and this function
// would need to be changed for tests to pass after upgrade.
// See https://github.com/gin-gonic/gin/pull/1015.
// func notOne(
// v *validator.Validate, topStruct reflect.Value, currentStructOrField reflect.Value,
// field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string,
// ) bool {
// if val, ok := field.Interface().(int); ok {
// return val != 1
// }
// return false
// }

View File

@ -0,0 +1,22 @@
package binding
import (
"encoding/xml"
"net/http"
"github.com/pkg/errors"
)
type xmlBinding struct{}
func (xmlBinding) Name() string {
return "xml"
}
func (xmlBinding) Bind(req *http.Request, obj interface{}) error {
decoder := xml.NewDecoder(req.Body)
if err := decoder.Decode(obj); err != nil {
return errors.WithStack(err)
}
return validate(obj)
}

View File

@ -0,0 +1,306 @@
package blademaster
import (
"context"
"math"
"net/http"
"strconv"
"github.com/bilibili/Kratos/pkg/ecode"
"github.com/bilibili/Kratos/pkg/net/http/blademaster/binding"
"github.com/bilibili/Kratos/pkg/net/http/blademaster/render"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/types"
"github.com/pkg/errors"
)
const (
_abortIndex int8 = math.MaxInt8 / 2
)
var (
_openParen = []byte("(")
_closeParen = []byte(")")
)
// Context is the most important part. It allows us to pass variables between
// middleware, manage the flow, validate the JSON of a request and render a
// JSON response for example.
type Context struct {
context.Context
Request *http.Request
Writer http.ResponseWriter
// flow control
index int8
handlers []HandlerFunc
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
Error error
method string
engine *Engine
}
/************************************/
/*********** FLOW CONTROL ***********/
/************************************/
// Next should be used only inside middleware.
// It executes the pending handlers in the chain inside the calling handler.
// See example in godoc.
func (c *Context) Next() {
c.index++
s := int8(len(c.handlers))
for ; c.index < s; c.index++ {
// only check method on last handler, otherwise middlewares
// will never be effected if request method is not matched
if c.index == s-1 && c.method != c.Request.Method {
code := http.StatusMethodNotAllowed
c.Error = ecode.MethodNotAllowed
http.Error(c.Writer, http.StatusText(code), code)
return
}
c.handlers[c.index](c)
}
}
// Abort prevents pending handlers from being called. Note that this will not stop the current handler.
// Let's say you have an authorization middleware that validates that the current request is authorized.
// If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers
// for this request are not called.
func (c *Context) Abort() {
c.index = _abortIndex
}
// AbortWithStatus calls `Abort()` and writes the headers with the specified status code.
// For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401).
func (c *Context) AbortWithStatus(code int) {
c.Status(code)
c.Abort()
}
// IsAborted returns true if the current context was aborted.
func (c *Context) IsAborted() bool {
return c.index >= _abortIndex
}
/************************************/
/******** METADATA MANAGEMENT********/
/************************************/
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = value
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
value, exists = c.Keys[key]
return
}
/************************************/
/******** RESPONSE RENDERING ********/
/************************************/
// bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
// Status sets the HTTP response code.
func (c *Context) Status(code int) {
c.Writer.WriteHeader(code)
}
// Render http response with http code by a render instance.
func (c *Context) Render(code int, r render.Render) {
r.WriteContentType(c.Writer)
if code > 0 {
c.Status(code)
}
if !bodyAllowedForStatus(code) {
return
}
params := c.Request.Form
cb := params.Get("callback")
jsonp := cb != "" && params.Get("jsonp") == "jsonp"
if jsonp {
c.Writer.Write([]byte(cb))
c.Writer.Write(_openParen)
}
if err := r.Render(c.Writer); err != nil {
c.Error = err
return
}
if jsonp {
if _, err := c.Writer.Write(_closeParen); err != nil {
c.Error = errors.WithStack(err)
}
}
}
// JSON serializes the given struct as JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSON(data interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.JSON{
Code: bcode.Code(),
Message: bcode.Message(),
Data: data,
})
}
// JSONMap serializes the given map as map JSON into the response body.
// It also sets the Content-Type as "application/json".
func (c *Context) JSONMap(data map[string]interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
data["code"] = bcode.Code()
if _, ok := data["message"]; !ok {
data["message"] = bcode.Message()
}
c.Render(code, render.MapJSON(data))
}
// XML serializes the given struct as XML into the response body.
// It also sets the Content-Type as "application/xml".
func (c *Context) XML(data interface{}, err error) {
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
// TODO app allow 5xx?
/*
if bcode.Code() == -500 {
code = http.StatusServiceUnavailable
}
*/
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.XML{
Code: bcode.Code(),
Message: bcode.Message(),
Data: data,
})
}
// Protobuf serializes the given struct as PB into the response body.
// It also sets the ContentType as "application/x-protobuf".
func (c *Context) Protobuf(data proto.Message, err error) {
var (
bytes []byte
)
code := http.StatusOK
c.Error = err
bcode := ecode.Cause(err)
any := new(types.Any)
if data != nil {
if bytes, err = proto.Marshal(data); err != nil {
c.Error = errors.WithStack(err)
return
}
any.TypeUrl = "type.googleapis.com/" + proto.MessageName(data)
any.Value = bytes
}
writeStatusCode(c.Writer, bcode.Code())
c.Render(code, render.PB{
Code: int64(bcode.Code()),
Message: bcode.Message(),
Data: any,
})
}
// Bytes writes some data into the body stream and updates the HTTP code.
func (c *Context) Bytes(code int, contentType string, data ...[]byte) {
c.Render(code, render.Data{
ContentType: contentType,
Data: data,
})
}
// String writes the given string into the response body.
func (c *Context) String(code int, format string, values ...interface{}) {
c.Render(code, render.String{Format: format, Data: values})
}
// Redirect returns a HTTP redirect to the specific location.
func (c *Context) Redirect(code int, location string) {
c.Render(-1, render.Redirect{
Code: code,
Location: location,
Request: c.Request,
})
}
// BindWith bind req arg with parser.
func (c *Context) BindWith(obj interface{}, b binding.Binding) error {
return c.mustBindWith(obj, b)
}
// Bind bind req arg with defult form binding.
func (c *Context) Bind(obj interface{}) error {
return c.mustBindWith(obj, binding.Form)
}
// mustBindWith binds the passed struct pointer using the specified binding engine.
// It will abort the request with HTTP 400 if any error ocurrs.
// See the binding package.
func (c *Context) mustBindWith(obj interface{}, b binding.Binding) (err error) {
if err = b.Bind(c.Request, obj); err != nil {
c.Error = ecode.RequestErr
c.Render(http.StatusOK, render.JSON{
Code: ecode.RequestErr.Code(),
Message: err.Error(),
Data: nil,
})
c.Abort()
}
return
}
func writeStatusCode(w http.ResponseWriter, ecode int) {
header := w.Header()
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
}

View File

@ -0,0 +1,249 @@
package blademaster
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/bilibili/Kratos/pkg/log"
"github.com/pkg/errors"
)
// CORSConfig represents all available options for the middleware.
type CORSConfig struct {
AllowAllOrigins bool
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// Default value is []
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It take the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
AllowOriginFunc func(origin string) bool
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET and POST)
AllowMethods []string
// AllowedHeaders is list of non simple headers the client is allowed to use with
// cross-domain requests.
AllowHeaders []string
// AllowCredentials indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
AllowCredentials bool
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
// API specification
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached
MaxAge time.Duration
}
type cors struct {
allowAllOrigins bool
allowCredentials bool
allowOriginFunc func(string) bool
allowOrigins []string
normalHeaders http.Header
preflightHeaders http.Header
}
type converter func(string) string
// Validate is check configuration of user defined.
func (c *CORSConfig) Validate() error {
if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
}
if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
return errors.New("conflict settings: all origins disabled")
}
for _, origin := range c.AllowOrigins {
if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
return errors.New("bad origin: origins must either be '*' or include http:// or https://")
}
}
return nil
}
// CORS returns the location middleware with default configuration.
func CORS(allowOriginHosts []string) HandlerFunc {
config := &CORSConfig{
AllowMethods: []string{"GET", "POST"},
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type"},
AllowCredentials: true,
MaxAge: time.Duration(0),
AllowOriginFunc: func(origin string) bool {
for _, host := range allowOriginHosts {
if strings.HasSuffix(strings.ToLower(origin), host) {
return true
}
}
return false
},
}
return newCORS(config)
}
// newCORS returns the location middleware with user-defined custom configuration.
func newCORS(config *CORSConfig) HandlerFunc {
if err := config.Validate(); err != nil {
panic(err.Error())
}
cors := &cors{
allowOriginFunc: config.AllowOriginFunc,
allowAllOrigins: config.AllowAllOrigins,
allowCredentials: config.AllowCredentials,
allowOrigins: normalize(config.AllowOrigins),
normalHeaders: generateNormalHeaders(config),
preflightHeaders: generatePreflightHeaders(config),
}
return func(c *Context) {
cors.applyCORS(c)
}
}
func (cors *cors) applyCORS(c *Context) {
origin := c.Request.Header.Get("Origin")
if len(origin) == 0 {
// request is not a CORS request
return
}
if !cors.validateOrigin(origin) {
log.V(5).Info("The request's Origin header `%s` does not match any of allowed origins.", origin)
c.AbortWithStatus(http.StatusForbidden)
return
}
if c.Request.Method == "OPTIONS" {
cors.handlePreflight(c)
defer c.AbortWithStatus(200)
} else {
cors.handleNormal(c)
}
if !cors.allowAllOrigins {
header := c.Writer.Header()
header.Set("Access-Control-Allow-Origin", origin)
}
}
func (cors *cors) validateOrigin(origin string) bool {
if cors.allowAllOrigins {
return true
}
for _, value := range cors.allowOrigins {
if value == origin {
return true
}
}
if cors.allowOriginFunc != nil {
return cors.allowOriginFunc(origin)
}
return false
}
func (cors *cors) handlePreflight(c *Context) {
header := c.Writer.Header()
for key, value := range cors.preflightHeaders {
header[key] = value
}
}
func (cors *cors) handleNormal(c *Context) {
header := c.Writer.Header()
for key, value := range cors.normalHeaders {
header[key] = value
}
}
func generateNormalHeaders(c *CORSConfig) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
// backport support for early browsers
if len(c.AllowMethods) > 0 {
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
value := strings.Join(allowMethods, ",")
headers.Set("Access-Control-Allow-Methods", value)
}
if len(c.ExposeHeaders) > 0 {
exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
}
if c.AllowAllOrigins {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Vary", "Origin")
}
return headers
}
func generatePreflightHeaders(c *CORSConfig) http.Header {
headers := make(http.Header)
if c.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if len(c.AllowMethods) > 0 {
allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
value := strings.Join(allowMethods, ",")
headers.Set("Access-Control-Allow-Methods", value)
}
if len(c.AllowHeaders) > 0 {
allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
value := strings.Join(allowHeaders, ",")
headers.Set("Access-Control-Allow-Headers", value)
}
if c.MaxAge > time.Duration(0) {
value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
headers.Set("Access-Control-Max-Age", value)
}
if c.AllowAllOrigins {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")
}
return headers
}
func normalize(values []string) []string {
if values == nil {
return nil
}
distinctMap := make(map[string]bool, len(values))
normalized := make([]string, 0, len(values))
for _, value := range values {
value = strings.TrimSpace(value)
value = strings.ToLower(value)
if _, seen := distinctMap[value]; !seen {
normalized = append(normalized, value)
distinctMap[value] = true
}
}
return normalized
}
func convert(s []string, c converter) []string {
var out []string
for _, i := range s {
out = append(out, c(i))
}
return out
}

View File

@ -0,0 +1,69 @@
package blademaster
import (
"net/url"
"regexp"
"strings"
"github.com/bilibili/Kratos/pkg/log"
)
func matchHostSuffix(suffix string) func(*url.URL) bool {
return func(uri *url.URL) bool {
return strings.HasSuffix(strings.ToLower(uri.Host), suffix)
}
}
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool {
return func(uri *url.URL) bool {
return pattern.MatchString(strings.ToLower(uri.String()))
}
}
// CSRF returns the csrf middleware to prevent invalid cross site request.
// Only referer is checked currently.
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc {
validations := []func(*url.URL) bool{}
addHostSuffix := func(suffix string) {
validations = append(validations, matchHostSuffix(suffix))
}
addPattern := func(pattern string) {
validations = append(validations, matchPattern(regexp.MustCompile(pattern)))
}
for _, r := range allowHosts {
addHostSuffix(r)
}
for _, p := range allowPattern {
addPattern(p)
}
return func(c *Context) {
referer := c.Request.Header.Get("Referer")
params := c.Request.Form
cross := (params.Get("callback") != "" && params.Get("jsonp") == "jsonp") || (params.Get("cross_domain") != "")
if referer == "" {
if !cross {
return
}
log.V(5).Info("The request's Referer header is empty.")
c.AbortWithStatus(403)
return
}
illegal := true
if uri, err := url.Parse(referer); err == nil && uri.Host != "" {
for _, validate := range validations {
if validate(uri) {
illegal = false
break
}
}
}
if illegal {
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer)
c.AbortWithStatus(403)
return
}
}
}

View File

@ -0,0 +1,69 @@
package blademaster
import (
"fmt"
"strconv"
"time"
"github.com/bilibili/Kratos/pkg/ecode"
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/net/metadata"
)
// Logger is logger middleware
func Logger() HandlerFunc {
const noUser = "no_user"
return func(c *Context) {
now := time.Now()
ip := metadata.String(c, metadata.RemoteIP)
req := c.Request
path := req.URL.Path
params := req.Form
var quota float64
if deadline, ok := c.Context.Deadline(); ok {
quota = time.Until(deadline).Seconds()
}
c.Next()
err := c.Error
cerr := ecode.Cause(err)
dt := time.Since(now)
caller := metadata.String(c, metadata.Caller)
if caller == "" {
caller = noUser
}
stats.Incr(caller, path[1:], strconv.FormatInt(int64(cerr.Code()), 10))
stats.Timing(caller, int64(dt/time.Millisecond), path[1:])
lf := log.Infov
errmsg := ""
isSlow := dt >= (time.Millisecond * 500)
if err != nil {
errmsg = err.Error()
lf = log.Errorv
if cerr.Code() > 0 {
lf = log.Warnv
}
} else {
if isSlow {
lf = log.Warnv
}
}
lf(c,
log.KVString("method", req.Method),
log.KVString("ip", ip),
log.KVString("user", caller),
log.KVString("path", path),
log.KVString("params", params.Encode()),
log.KVInt("ret", cerr.Code()),
log.KVString("msg", cerr.Message()),
log.KVString("stack", fmt.Sprintf("%+v", err)),
log.KVString("err", errmsg),
log.KVFloat64("timeout_quota", quota),
log.KVFloat64("ts", dt.Seconds()),
log.KVString("source", "http-access-log"),
)
}
}

View File

@ -17,13 +17,14 @@ const (
_httpHeaderUser = "x1-bmspy-user" _httpHeaderUser = "x1-bmspy-user"
_httpHeaderColor = "x1-bmspy-color" _httpHeaderColor = "x1-bmspy-color"
_httpHeaderTimeout = "x1-bmspy-timeout" _httpHeaderTimeout = "x1-bmspy-timeout"
_httpHeaderMirror = "x1-bmspy-mirror"
_httpHeaderRemoteIP = "x-backend-bm-real-ip" _httpHeaderRemoteIP = "x-backend-bm-real-ip"
_httpHeaderRemoteIPPort = "x-backend-bm-real-ipport" _httpHeaderRemoteIPPort = "x-backend-bm-real-ipport"
) )
// mirror return true if x1-bilispy-mirror in http header and its value is 1 or true. // mirror return true if x-bmspy-mirror in http header and its value is 1 or true.
func mirror(req *http.Request) bool { func mirror(req *http.Request) bool {
mirrorStr := req.Header.Get("x1-bilispy-mirror") mirrorStr := req.Header.Get(_httpHeaderMirror)
if mirrorStr == "" { if mirrorStr == "" {
return false return false
} }
@ -79,7 +80,7 @@ func timeout(req *http.Request) time.Duration {
} }
// remoteIP implements a best effort algorithm to return the real client IP, it parses // remoteIP implements a best effort algorithm to return the real client IP, it parses
// X-BACKEND-BILI-REAL-IP or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // x-backend-bm-real-ip or X-Real-IP or X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
func remoteIP(req *http.Request) (remote string) { func remoteIP(req *http.Request) (remote string) {
if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" { if remote = req.Header.Get(_httpHeaderRemoteIP); remote != "" && remote != "null" {

View File

@ -0,0 +1,46 @@
package blademaster
import (
"flag"
"net/http"
"net/http/pprof"
"os"
"sync"
"github.com/bilibili/Kratos/pkg/conf/dsn"
"github.com/pkg/errors"
)
var (
_perfOnce sync.Once
_perfDSN string
)
func init() {
v := os.Getenv("HTTP_PERF")
if v == "" {
v = "tcp://0.0.0.0:2333"
}
flag.StringVar(&_perfDSN, "http.perf", v, "listen http perf dsn, or use HTTP_PERF env variable.")
}
func startPerf() {
_perfOnce.Do(func() {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
go func() {
d, err := dsn.Parse(_perfDSN)
if err != nil {
panic(errors.Errorf("blademaster: http perf dsn must be tcp://$host:port, %s:error(%v)", _perfDSN, err))
}
if err := http.ListenAndServe(d.Host, mux); err != nil {
panic(errors.Errorf("blademaster: listen %s: error(%v)", d.Host, err))
}
}()
})
}

View File

@ -0,0 +1,12 @@
package blademaster
import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func monitor() HandlerFunc {
return func(c *Context) {
h := promhttp.Handler()
h.ServeHTTP(c.Writer, c.Request)
}
}

View File

@ -0,0 +1,32 @@
package blademaster
import (
"fmt"
"net/http/httputil"
"os"
"runtime"
"github.com/bilibili/Kratos/pkg/log"
)
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
return func(c *Context) {
defer func() {
var rawReq []byte
if err := recover(); err != nil {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
if c.Request != nil {
rawReq, _ = httputil.DumpRequest(c.Request, false)
}
pl := fmt.Sprintf("http call panic: %s\n%v\n%s\n", string(rawReq), err, buf)
fmt.Fprintf(os.Stderr, pl)
log.Error(pl)
c.AbortWithStatus(500)
}
}()
c.Next()
}
}

View File

@ -0,0 +1,30 @@
package render
import (
"net/http"
"github.com/pkg/errors"
)
// Data common bytes struct.
type Data struct {
ContentType string
Data [][]byte
}
// Render (Data) writes data with custom ContentType.
func (r Data) Render(w http.ResponseWriter) (err error) {
r.WriteContentType(w)
for _, d := range r.Data {
if _, err = w.Write(d); err != nil {
err = errors.WithStack(err)
return
}
}
return
}
// WriteContentType writes data with custom ContentType.
func (r Data) WriteContentType(w http.ResponseWriter) {
writeContentType(w, []string{r.ContentType})
}

View File

@ -0,0 +1,58 @@
package render
import (
"encoding/json"
"net/http"
"github.com/pkg/errors"
)
var jsonContentType = []string{"application/json; charset=utf-8"}
// JSON common json struct.
type JSON struct {
Code int `json:"code"`
Message string `json:"message"`
TTL int `json:"ttl"`
Data interface{} `json:"data,omitempty"`
}
func writeJSON(w http.ResponseWriter, obj interface{}) (err error) {
var jsonBytes []byte
writeContentType(w, jsonContentType)
if jsonBytes, err = json.Marshal(obj); err != nil {
err = errors.WithStack(err)
return
}
if _, err = w.Write(jsonBytes); err != nil {
err = errors.WithStack(err)
}
return
}
// Render (JSON) writes data with json ContentType.
func (r JSON) Render(w http.ResponseWriter) error {
// FIXME(zhoujiahui): the TTL field will be configurable in the future
if r.TTL <= 0 {
r.TTL = 1
}
return writeJSON(w, r)
}
// WriteContentType write json ContentType.
func (r JSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}
// MapJSON common map json struct.
type MapJSON map[string]interface{}
// Render (MapJSON) writes data with json ContentType.
func (m MapJSON) Render(w http.ResponseWriter) error {
return writeJSON(w, m)
}
// WriteContentType write json ContentType.
func (m MapJSON) WriteContentType(w http.ResponseWriter) {
writeContentType(w, jsonContentType)
}

View File

@ -0,0 +1,38 @@
package render
import (
"net/http"
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
)
var pbContentType = []string{"application/x-protobuf"}
// Render (PB) writes data with protobuf ContentType.
func (r PB) Render(w http.ResponseWriter) error {
if r.TTL <= 0 {
r.TTL = 1
}
return writePB(w, r)
}
// WriteContentType write protobuf ContentType.
func (r PB) WriteContentType(w http.ResponseWriter) {
writeContentType(w, pbContentType)
}
func writePB(w http.ResponseWriter, obj PB) (err error) {
var pbBytes []byte
writeContentType(w, pbContentType)
if pbBytes, err = proto.Marshal(&obj); err != nil {
err = errors.WithStack(err)
return
}
if _, err = w.Write(pbBytes); err != nil {
err = errors.WithStack(err)
}
return
}

View File

@ -0,0 +1,26 @@
package render
import (
"net/http"
"github.com/pkg/errors"
)
// Redirect render for redirect to specified location.
type Redirect struct {
Code int
Request *http.Request
Location string
}
// Render (Redirect) redirect to specified location.
func (r Redirect) Render(w http.ResponseWriter) error {
if (r.Code < 300 || r.Code > 308) && r.Code != 201 {
return errors.Errorf("Cannot redirect with status code %d", r.Code)
}
http.Redirect(w, r.Request, r.Location, r.Code)
return nil
}
// WriteContentType noneContentType.
func (r Redirect) WriteContentType(http.ResponseWriter) {}

View File

@ -0,0 +1,30 @@
package render
import (
"net/http"
)
// Render http reponse render.
type Render interface {
// Render render it to http response writer.
Render(http.ResponseWriter) error
// WriteContentType write content-type to http response writer.
WriteContentType(w http.ResponseWriter)
}
var (
_ Render = JSON{}
_ Render = MapJSON{}
_ Render = XML{}
_ Render = String{}
_ Render = Redirect{}
_ Render = Data{}
_ Render = PB{}
)
func writeContentType(w http.ResponseWriter, value []string) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = value
}
}

View File

@ -0,0 +1,89 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: pb.proto
/*
Package render is a generated protocol buffer package.
It is generated from these files:
pb.proto
It has these top-level messages:
PB
*/
package render
import proto "github.com/gogo/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/gogo/protobuf/types"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
type PB struct {
Code int64 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"`
Message string `protobuf:"bytes,2,opt,name=Message,proto3" json:"Message,omitempty"`
TTL uint64 `protobuf:"varint,3,opt,name=TTL,proto3" json:"TTL,omitempty"`
Data *google_protobuf.Any `protobuf:"bytes,4,opt,name=Data" json:"Data,omitempty"`
}
func (m *PB) Reset() { *m = PB{} }
func (m *PB) String() string { return proto.CompactTextString(m) }
func (*PB) ProtoMessage() {}
func (*PB) Descriptor() ([]byte, []int) { return fileDescriptorPb, []int{0} }
func (m *PB) GetCode() int64 {
if m != nil {
return m.Code
}
return 0
}
func (m *PB) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
func (m *PB) GetTTL() uint64 {
if m != nil {
return m.TTL
}
return 0
}
func (m *PB) GetData() *google_protobuf.Any {
if m != nil {
return m.Data
}
return nil
}
func init() {
proto.RegisterType((*PB)(nil), "render.PB")
}
func init() { proto.RegisterFile("pb.proto", fileDescriptorPb) }
var fileDescriptorPb = []byte{
// 154 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x28, 0x48, 0xd2, 0x2b,
0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0x4a, 0xcd, 0x4b, 0x49, 0x2d, 0x92, 0x92, 0x4c, 0xcf,
0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 0x26, 0x95, 0xa6, 0xe9, 0x27, 0xe6, 0x55, 0x42, 0x94,
0x28, 0xe5, 0x71, 0x31, 0x05, 0x38, 0x09, 0x09, 0x71, 0xb1, 0x38, 0xe7, 0xa7, 0xa4, 0x4a, 0x30,
0x2a, 0x30, 0x6a, 0x30, 0x07, 0x81, 0xd9, 0x42, 0x12, 0x5c, 0xec, 0xbe, 0xa9, 0xc5, 0xc5, 0x89,
0xe9, 0xa9, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x00, 0x17, 0x73, 0x48,
0x88, 0x8f, 0x04, 0xb3, 0x02, 0xa3, 0x06, 0x4b, 0x10, 0x88, 0x29, 0xa4, 0xc1, 0xc5, 0xe2, 0x92,
0x58, 0x92, 0x28, 0xc1, 0xa2, 0xc0, 0xa8, 0xc1, 0x6d, 0x24, 0xa2, 0x07, 0xb1, 0x4f, 0x0f, 0x66,
0x9f, 0x9e, 0x63, 0x5e, 0x65, 0x10, 0x58, 0x45, 0x12, 0x1b, 0x58, 0xcc, 0x18, 0x10, 0x00, 0x00,
0xff, 0xff, 0x7a, 0x92, 0x16, 0x71, 0xa5, 0x00, 0x00, 0x00,
}

View File

@ -0,0 +1,14 @@
// use under command to generate pb.pb.go
// protoc --proto_path=.:$GOPATH/src/github.com/gogo/protobuf --gogo_out=Mgoogle/protobuf/any.proto=github.com/gogo/protobuf/types:. *.proto
syntax = "proto3";
package render;
import "google/protobuf/any.proto";
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
message PB {
int64 Code = 1;
string Message = 2;
uint64 TTL = 3;
google.protobuf.Any Data = 4;
}

View File

@ -0,0 +1,40 @@
package render
import (
"fmt"
"io"
"net/http"
"github.com/pkg/errors"
)
var plainContentType = []string{"text/plain; charset=utf-8"}
// String common string struct.
type String struct {
Format string
Data []interface{}
}
// Render (String) writes data with custom ContentType.
func (r String) Render(w http.ResponseWriter) error {
return writeString(w, r.Format, r.Data)
}
// WriteContentType writes string with text/plain ContentType.
func (r String) WriteContentType(w http.ResponseWriter) {
writeContentType(w, plainContentType)
}
func writeString(w http.ResponseWriter, format string, data []interface{}) (err error) {
writeContentType(w, plainContentType)
if len(data) > 0 {
_, err = fmt.Fprintf(w, format, data...)
} else {
_, err = io.WriteString(w, format)
}
if err != nil {
err = errors.WithStack(err)
}
return
}

View File

@ -0,0 +1,31 @@
package render
import (
"encoding/xml"
"net/http"
"github.com/pkg/errors"
)
// XML common xml struct.
type XML struct {
Code int
Message string
Data interface{}
}
var xmlContentType = []string{"application/xml; charset=utf-8"}
// Render (XML) writes data with xml ContentType.
func (r XML) Render(w http.ResponseWriter) (err error) {
r.WriteContentType(w)
if err = xml.NewEncoder(w).Encode(r.Data); err != nil {
err = errors.WithStack(err)
}
return
}
// WriteContentType write xml ContentType.
func (r XML) WriteContentType(w http.ResponseWriter) {
writeContentType(w, xmlContentType)
}

View File

@ -0,0 +1,166 @@
package blademaster
import (
"regexp"
)
// IRouter http router framework interface.
type IRouter interface {
IRoutes
Group(string, ...HandlerFunc) *RouterGroup
}
// IRoutes http router interface.
type IRoutes interface {
UseFunc(...HandlerFunc) IRoutes
Use(...Handler) IRoutes
Handle(string, string, ...HandlerFunc) IRoutes
HEAD(string, ...HandlerFunc) IRoutes
GET(string, ...HandlerFunc) IRoutes
POST(string, ...HandlerFunc) IRoutes
PUT(string, ...HandlerFunc) IRoutes
DELETE(string, ...HandlerFunc) IRoutes
}
// RouterGroup is used internally to configure router, a RouterGroup is associated with a prefix
// and an array of handlers (middleware).
type RouterGroup struct {
Handlers []HandlerFunc
basePath string
engine *Engine
root bool
baseConfig *MethodConfig
}
var _ IRouter = &RouterGroup{}
// Use adds middleware to the group, see example code in doc.
func (group *RouterGroup) Use(middleware ...Handler) IRoutes {
for _, m := range middleware {
group.Handlers = append(group.Handlers, m.ServeHTTP)
}
return group.returnObj()
}
// UseFunc adds middleware to the group, see example code in doc.
func (group *RouterGroup) UseFunc(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
// Group creates a new router group. You should add all the routes that have common middlwares or the same path prefix.
// For example, all the routes that use a common middlware for authorization could be grouped.
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
return &RouterGroup{
Handlers: group.combineHandlers(handlers),
basePath: group.calculateAbsolutePath(relativePath),
engine: group.engine,
root: false,
}
}
// SetMethodConfig is used to set config on specified method
func (group *RouterGroup) SetMethodConfig(config *MethodConfig) *RouterGroup {
group.baseConfig = config
return group
}
// BasePath router group base path.
func (group *RouterGroup) BasePath() string {
return group.basePath
}
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
absolutePath := group.calculateAbsolutePath(relativePath)
injections := group.injections(relativePath)
handlers = group.combineHandlers(injections, handlers)
group.engine.addRoute(httpMethod, absolutePath, handlers...)
if group.baseConfig != nil {
group.engine.SetMethodConfig(absolutePath, group.baseConfig)
}
return group.returnObj()
}
// Handle registers a new request handle and middleware with the given path and method.
// The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes.
// See the example code in doc.
//
// For HEAD, GET, POST, PUT, and DELETE requests the respective shortcut
// functions can be used.
//
// This function is intended for bulk loading and to allow the usage of less
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
if matches, err := regexp.MatchString("^[A-Z]+$", httpMethod); !matches || err != nil {
panic("http method " + httpMethod + " is not valid")
}
return group.handle(httpMethod, relativePath, handlers...)
}
// HEAD is a shortcut for router.Handle("HEAD", path, handle).
func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("HEAD", relativePath, handlers...)
}
// GET is a shortcut for router.Handle("GET", path, handle).
func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("GET", relativePath, handlers...)
}
// POST is a shortcut for router.Handle("POST", path, handle).
func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("POST", relativePath, handlers...)
}
// PUT is a shortcut for router.Handle("PUT", path, handle).
func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("PUT", relativePath, handlers...)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handle).
func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle("DELETE", relativePath, handlers...)
}
func (group *RouterGroup) combineHandlers(handlerGroups ...[]HandlerFunc) []HandlerFunc {
finalSize := len(group.Handlers)
for _, handlers := range handlerGroups {
finalSize += len(handlers)
}
if finalSize >= int(_abortIndex) {
panic("too many handlers")
}
mergedHandlers := make([]HandlerFunc, finalSize)
copy(mergedHandlers, group.Handlers)
position := len(group.Handlers)
for _, handlers := range handlerGroups {
copy(mergedHandlers[position:], handlers)
position += len(handlers)
}
return mergedHandlers
}
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
return joinPaths(group.basePath, relativePath)
}
func (group *RouterGroup) returnObj() IRoutes {
if group.root {
return group.engine
}
return group
}
// injections is
func (group *RouterGroup) injections(relativePath string) []HandlerFunc {
absPath := group.calculateAbsolutePath(relativePath)
for _, injection := range group.engine.injections {
if !injection.pattern.MatchString(absPath) {
continue
}
return injection.handlers
}
return nil
}

View File

@ -0,0 +1,445 @@
package blademaster
import (
"context"
"flag"
"fmt"
"net"
"net/http"
"os"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/bilibili/Kratos/pkg/conf/dsn"
"github.com/bilibili/Kratos/pkg/log"
"github.com/bilibili/Kratos/pkg/net/ip"
"github.com/bilibili/Kratos/pkg/net/metadata"
"github.com/bilibili/Kratos/pkg/stat"
xtime "github.com/bilibili/Kratos/pkg/time"
"github.com/pkg/errors"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
)
var (
_ IRouter = &Engine{}
stats = stat.HTTPServer
_httpDSN string
)
func init() {
addFlag(flag.CommandLine)
}
func addFlag(fs *flag.FlagSet) {
v := os.Getenv("HTTP")
if v == "" {
v = "tcp://0.0.0.0:8000/?timeout=1s"
}
fs.StringVar(&_httpDSN, "http", v, "listen http dsn, or use HTTP env variable.")
}
func parseDSN(rawdsn string) *ServerConfig {
conf := new(ServerConfig)
d, err := dsn.Parse(rawdsn)
if err != nil {
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
}
if _, err = d.Bind(conf); err != nil {
panic(errors.Wrapf(err, "blademaster: invalid dsn: %s", rawdsn))
}
return conf
}
// Handler responds to an HTTP request.
type Handler interface {
ServeHTTP(c *Context)
}
// HandlerFunc http request handler function.
type HandlerFunc func(*Context)
// ServeHTTP calls f(ctx).
func (f HandlerFunc) ServeHTTP(c *Context) {
f(c)
}
// ServerConfig is the bm server config model
type ServerConfig struct {
Network string `dsn:"network"`
// FIXME: rename to Address
Addr string `dsn:"address"`
Timeout xtime.Duration `dsn:"query.timeout"`
ReadTimeout xtime.Duration `dsn:"query.readTimeout"`
WriteTimeout xtime.Duration `dsn:"query.writeTimeout"`
}
// MethodConfig is
type MethodConfig struct {
Timeout xtime.Duration
}
// Start listen and serve bm engine by given DSN.
func (engine *Engine) Start() error {
conf := engine.conf
l, err := net.Listen(conf.Network, conf.Addr)
if err != nil {
errors.Wrapf(err, "blademaster: listen tcp: %s", conf.Addr)
return err
}
log.Info("blademaster: start http listen addr: %s", conf.Addr)
server := &http.Server{
ReadTimeout: time.Duration(conf.ReadTimeout),
WriteTimeout: time.Duration(conf.WriteTimeout),
}
go func() {
if err := engine.RunServer(server, l); err != nil {
if errors.Cause(err) == http.ErrServerClosed {
log.Info("blademaster: server closed")
return
}
panic(errors.Wrapf(err, "blademaster: engine.ListenServer(%+v, %+v)", server, l))
}
}()
return nil
}
// Engine is the framework's instance, it contains the muxer, middleware and configuration settings.
// Create an instance of Engine, by using New() or Default()
type Engine struct {
RouterGroup
lock sync.RWMutex
conf *ServerConfig
address string
mux *http.ServeMux // http mux router
server atomic.Value // store *http.Server
metastore map[string]map[string]interface{} // metastore is the path as key and the metadata of this path as value, it export via /metadata
pcLock sync.RWMutex
methodConfigs map[string]*MethodConfig
injections []injection
}
type injection struct {
pattern *regexp.Regexp
handlers []HandlerFunc
}
// New returns a new blank Engine instance without any middleware attached.
//
// Deprecated: please use NewServer.
func New() *Engine {
engine := &Engine{
RouterGroup: RouterGroup{
Handlers: nil,
basePath: "/",
root: true,
},
address: ip.InternalIP(),
conf: &ServerConfig{
Timeout: xtime.Duration(time.Second),
},
mux: http.NewServeMux(),
metastore: make(map[string]map[string]interface{}),
methodConfigs: make(map[string]*MethodConfig),
injections: make([]injection, 0),
}
engine.RouterGroup.engine = engine
// NOTE add prometheus monitor location
engine.addRoute("GET", "/metrics", monitor())
engine.addRoute("GET", "/metadata", engine.metadata())
startPerf()
return engine
}
// NewServer returns a new blank Engine instance without any middleware attached.
func NewServer(conf *ServerConfig) *Engine {
if conf == nil {
if !flag.Parsed() {
fmt.Fprint(os.Stderr, "[blademaster] please call flag.Parse() before Init warden server, some configure may not effect.\n")
}
conf = parseDSN(_httpDSN)
} else {
fmt.Fprintf(os.Stderr, "[blademaster] config will be deprecated, argument will be ignored. please use -http flag or HTTP env to configure http server.\n")
}
engine := &Engine{
RouterGroup: RouterGroup{
Handlers: nil,
basePath: "/",
root: true,
},
address: ip.InternalIP(),
mux: http.NewServeMux(),
metastore: make(map[string]map[string]interface{}),
methodConfigs: make(map[string]*MethodConfig),
}
if err := engine.SetConfig(conf); err != nil {
panic(err)
}
engine.RouterGroup.engine = engine
// NOTE add prometheus monitor location
engine.addRoute("GET", "/metrics", monitor())
engine.addRoute("GET", "/metadata", engine.metadata())
startPerf()
return engine
}
// SetMethodConfig is used to set config on specified path
func (engine *Engine) SetMethodConfig(path string, mc *MethodConfig) {
engine.pcLock.Lock()
engine.methodConfigs[path] = mc
engine.pcLock.Unlock()
}
// DefaultServer returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
func DefaultServer(conf *ServerConfig) *Engine {
engine := NewServer(conf)
engine.Use(Recovery(), Trace(), Logger())
return engine
}
// Default returns an Engine instance with the Recovery, Logger and CSRF middleware already attached.
//
// Deprecated: please use DefaultServer.
func Default() *Engine {
engine := New()
engine.Use(Recovery(), Trace(), Logger())
return engine
}
func (engine *Engine) addRoute(method, path string, handlers ...HandlerFunc) {
if path[0] != '/' {
panic("blademaster: path must begin with '/'")
}
if method == "" {
panic("blademaster: HTTP method can not be empty")
}
if len(handlers) == 0 {
panic("blademaster: there must be at least one handler")
}
if _, ok := engine.metastore[path]; !ok {
engine.metastore[path] = make(map[string]interface{})
}
engine.metastore[path]["method"] = method
engine.mux.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
c := &Context{
Context: nil,
engine: engine,
index: -1,
handlers: nil,
Keys: nil,
method: "",
Error: nil,
}
c.Request = req
c.Writer = w
c.handlers = handlers
c.method = method
engine.handleContext(c)
})
}
// SetConfig is used to set the engine configuration.
// Only the valid config will be loaded.
func (engine *Engine) SetConfig(conf *ServerConfig) (err error) {
if conf.Timeout <= 0 {
return errors.New("blademaster: config timeout must greater than 0")
}
if conf.Network == "" {
conf.Network = "tcp"
}
engine.lock.Lock()
engine.conf = conf
engine.lock.Unlock()
return
}
func (engine *Engine) methodConfig(path string) *MethodConfig {
engine.pcLock.RLock()
mc := engine.methodConfigs[path]
engine.pcLock.RUnlock()
return mc
}
func (engine *Engine) handleContext(c *Context) {
var cancel func()
req := c.Request
ctype := req.Header.Get("Content-Type")
switch {
case strings.Contains(ctype, "multipart/form-data"):
req.ParseMultipartForm(defaultMaxMemory)
default:
req.ParseForm()
}
// get derived timeout from http request header,
// compare with the engine configured,
// and use the minimum one
engine.lock.RLock()
tm := time.Duration(engine.conf.Timeout)
engine.lock.RUnlock()
// the method config is preferred
if pc := engine.methodConfig(c.Request.URL.Path); pc != nil {
tm = time.Duration(pc.Timeout)
}
if ctm := timeout(req); ctm > 0 && tm > ctm {
tm = ctm
}
md := metadata.MD{
metadata.Color: color(req),
metadata.RemoteIP: remoteIP(req),
metadata.RemotePort: remotePort(req),
metadata.Caller: caller(req),
metadata.Mirror: mirror(req),
}
ctx := metadata.NewContext(context.Background(), md)
if tm > 0 {
c.Context, cancel = context.WithTimeout(ctx, tm)
} else {
c.Context, cancel = context.WithCancel(ctx)
}
defer cancel()
c.Next()
}
// Router return a http.Handler for using http.ListenAndServe() directly.
func (engine *Engine) Router() http.Handler {
return engine.mux
}
// Server is used to load stored http server.
func (engine *Engine) Server() *http.Server {
s, ok := engine.server.Load().(*http.Server)
if !ok {
return nil
}
return s
}
// Shutdown the http server without interrupting active connections.
func (engine *Engine) Shutdown(ctx context.Context) error {
server := engine.Server()
if server == nil {
return errors.New("blademaster: no server")
}
return errors.WithStack(server.Shutdown(ctx))
}
// UseFunc attachs a global middleware to the router. ie. the middleware attached though UseFunc() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
func (engine *Engine) UseFunc(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.UseFunc(middleware...)
return engine
}
// Use attachs a global middleware to the router. ie. the middleware attached though Use() will be
// included in the handlers chain for every single request. Even 404, 405, static files...
// For example, this is the right place for a logger or error management middleware.
func (engine *Engine) Use(middleware ...Handler) IRoutes {
engine.RouterGroup.Use(middleware...)
return engine
}
// Ping is used to set the general HTTP ping handler.
func (engine *Engine) Ping(handler HandlerFunc) {
engine.GET("/monitor/ping", handler)
}
// Register is used to export metadata to discovery.
func (engine *Engine) Register(handler HandlerFunc) {
engine.GET("/register", handler)
}
// Run attaches the router to a http.Server and starts listening and serving HTTP requests.
// It is a shortcut for http.ListenAndServe(addr, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) Run(addr ...string) (err error) {
address := resolveAddress(addr)
server := &http.Server{
Addr: address,
Handler: engine.mux,
}
engine.server.Store(server)
if err = server.ListenAndServe(); err != nil {
err = errors.Wrapf(err, "addrs: %v", addr)
}
return
}
// RunTLS attaches the router to a http.Server and starts listening and serving HTTPS (secure) requests.
// It is a shortcut for http.ListenAndServeTLS(addr, certFile, keyFile, router)
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunTLS(addr, certFile, keyFile string) (err error) {
server := &http.Server{
Addr: addr,
Handler: engine.mux,
}
engine.server.Store(server)
if err = server.ListenAndServeTLS(certFile, keyFile); err != nil {
err = errors.Wrapf(err, "tls: %s/%s:%s", addr, certFile, keyFile)
}
return
}
// RunUnix attaches the router to a http.Server and starts listening and serving HTTP requests
// through the specified unix socket (ie. a file).
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunUnix(file string) (err error) {
os.Remove(file)
listener, err := net.Listen("unix", file)
if err != nil {
err = errors.Wrapf(err, "unix: %s", file)
return
}
defer listener.Close()
server := &http.Server{
Handler: engine.mux,
}
engine.server.Store(server)
if err = server.Serve(listener); err != nil {
err = errors.Wrapf(err, "unix: %s", file)
}
return
}
// RunServer will serve and start listening HTTP requests by given server and listener.
// Note: this method will block the calling goroutine indefinitely unless an error happens.
func (engine *Engine) RunServer(server *http.Server, l net.Listener) (err error) {
server.Handler = engine.mux
engine.server.Store(server)
if err = server.Serve(l); err != nil {
err = errors.Wrapf(err, "listen server: %+v/%+v", server, l)
return
}
return
}
func (engine *Engine) metadata() HandlerFunc {
return func(c *Context) {
c.JSON(engine.metastore, nil)
}
}
// Inject is
func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {
engine.injections = append(engine.injections, injection{
pattern: regexp.MustCompile(pattern),
handlers: handlers,
})
}

View File

@ -4,12 +4,43 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"strconv"
"github.com/bilibili/Kratos/pkg/net/metadata"
"github.com/bilibili/Kratos/pkg/net/trace" "github.com/bilibili/Kratos/pkg/net/trace"
) )
const _defaultComponentName = "net/http" const _defaultComponentName = "net/http"
// Trace is trace middleware
func Trace() HandlerFunc {
return func(c *Context) {
// handle http request
// get derived trace from http request header
t, err := trace.Extract(trace.HTTPFormat, c.Request.Header)
if err != nil {
var opts []trace.Option
if ok, _ := strconv.ParseBool(trace.KratosTraceDebug); ok {
opts = append(opts, trace.EnableDebug())
}
t = trace.New(c.Request.URL.Path, opts...)
}
t.SetTitle(c.Request.URL.Path)
t.SetTag(trace.String(trace.TagComponent, _defaultComponentName))
t.SetTag(trace.String(trace.TagHTTPMethod, c.Request.Method))
t.SetTag(trace.String(trace.TagHTTPURL, c.Request.URL.String()))
t.SetTag(trace.String(trace.TagSpanKind, "server"))
// business tag
t.SetTag(trace.String("caller", metadata.String(c.Context, metadata.Caller)))
// export trace id to user.
// TODO(zhoujiahui): trace package should be updated
// c.Writer.Header().Set(trace.KratosTraceID, t.TraceID())
c.Context = trace.NewContext(c.Context, t)
c.Next()
t.Finish(&c.Error)
}
}
type closeTracker struct { type closeTracker struct {
io.ReadCloser io.ReadCloser
tr trace.Trace tr trace.Trace

View File

@ -0,0 +1,42 @@
package blademaster
import (
"os"
"path"
)
func lastChar(str string) uint8 {
if str == "" {
panic("The length of the string can't be 0")
}
return str[len(str)-1]
}
func joinPaths(absolutePath, relativePath string) string {
if relativePath == "" {
return absolutePath
}
finalPath := path.Join(absolutePath, relativePath)
appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/'
if appendSlash {
return finalPath + "/"
}
return finalPath
}
func resolveAddress(addr []string) string {
switch len(addr) {
case 0:
if port := os.Getenv("PORT"); port != "" {
//debugPrint("Environment variable PORT=\"%s\"", port)
return ":" + port
}
//debugPrint("Environment variable PORT is undefined. Using port :8080 by default")
return ":8080"
case 1:
return addr[0]
default:
panic("too much parameters")
}
}