1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-24 03:46:37 +02:00

【#562】fix context Keys map concurrent issue (#561)

* 1、Fix context Keys map concurrent issue
2、add RemoteIP func to get client ip

* context mutex default to zero-value

Co-authored-by: John Sun <sunqiang@styd.cn>
This commit is contained in:
Klaus 2020-04-30 17:13:23 +08:00 committed by GitHub
parent d99f595ea1
commit d29dfdfd67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 12 deletions

View File

@ -5,8 +5,12 @@ import (
"math"
"net/http"
"strconv"
"strings"
"sync"
"text/template"
"github.com/go-kratos/kratos/pkg/net/metadata"
"github.com/go-kratos/kratos/pkg/ecode"
"github.com/go-kratos/kratos/pkg/net/http/blademaster/binding"
"github.com/go-kratos/kratos/pkg/net/http/blademaster/render"
@ -40,6 +44,8 @@ type Context struct {
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
// This mutex protect Keys map
keysMutex sync.RWMutex
Error error
@ -51,6 +57,20 @@ type Context struct {
Params Params
}
/************************************/
/********** CONTEXT CREATION ********/
/************************************/
func (c *Context) reset() {
c.Context = nil
c.index = -1
c.handlers = nil
c.Keys = nil
c.Error = nil
c.method = ""
c.RoutePath = ""
c.Params = c.Params[0:0]
}
/************************************/
/*********** FLOW CONTROL ***********/
/************************************/
@ -93,16 +113,76 @@ func (c *Context) IsAborted() bool {
// Set is used to store a new key/value pair exclusively for this context.
// It also lazy initializes c.Keys if it was not used previously.
func (c *Context) Set(key string, value interface{}) {
c.keysMutex.Lock()
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = value
c.keysMutex.Unlock()
}
// Get returns the value for the given key, ie: (value, true).
// If the value does not exists it returns (nil, false)
func (c *Context) Get(key string) (value interface{}, exists bool) {
c.keysMutex.RLock()
value, exists = c.Keys[key]
c.keysMutex.RUnlock()
return
}
// GetString returns the value associated with the key as a string.
func (c *Context) GetString(key string) (s string) {
if val, ok := c.Get(key); ok && val != nil {
s, _ = val.(string)
}
return
}
// GetBool returns the value associated with the key as a boolean.
func (c *Context) GetBool(key string) (b bool) {
if val, ok := c.Get(key); ok && val != nil {
b, _ = val.(bool)
}
return
}
// GetInt returns the value associated with the key as an integer.
func (c *Context) GetInt(key string) (i int) {
if val, ok := c.Get(key); ok && val != nil {
i, _ = val.(int)
}
return
}
// GetUint returns the value associated with the key as an unsigned integer.
func (c *Context) GetUint(key string) (ui uint) {
if val, ok := c.Get(key); ok && val != nil {
ui, _ = val.(uint)
}
return
}
// GetInt64 returns the value associated with the key as an integer.
func (c *Context) GetInt64(key string) (i64 int64) {
if val, ok := c.Get(key); ok && val != nil {
i64, _ = val.(int64)
}
return
}
// GetUint64 returns the value associated with the key as an unsigned integer.
func (c *Context) GetUint64(key string) (ui64 uint64) {
if val, ok := c.Get(key); ok && val != nil {
ui64, _ = val.(uint64)
}
return
}
// GetFloat64 returns the value associated with the key as a float64.
func (c *Context) GetFloat64(key string) (f64 float64) {
if val, ok := c.Get(key); ok && val != nil {
f64, _ = val.(float64)
}
return
}
@ -307,3 +387,22 @@ func writeStatusCode(w http.ResponseWriter, ecode int) {
header := w.Header()
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
}
// RemoteIP implements a best effort algorithm to return the real client IP, it parses
// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
// Notice: metadata.RemoteIP take precedence over X-Forwarded-For and X-Real-Ip
func (c *Context) RemoteIP() (remoteIP string) {
remoteIP = metadata.String(c, metadata.RemoteIP)
if remoteIP != "" {
return
}
remoteIP = c.Request.Header.Get("X-Forwarded-For")
remoteIP = strings.TrimSpace(strings.Split(remoteIP, ",")[0])
if remoteIP == "" {
remoteIP = strings.TrimSpace(c.Request.Header.Get("X-Real-Ip"))
}
return
}

View File

@ -15,7 +15,6 @@ func Logger() HandlerFunc {
const noUser = "no_user"
return func(c *Context) {
now := time.Now()
ip := metadata.String(c, metadata.RemoteIP)
req := c.Request
path := req.URL.Path
params := req.Form
@ -55,7 +54,7 @@ func Logger() HandlerFunc {
}
lf(c,
log.KVString("method", req.Method),
log.KVString("ip", ip),
log.KVString("ip", c.RemoteIP()),
log.KVString("user", caller),
log.KVString("path", path),
log.KVString("params", params.Encode()),

View File

@ -151,6 +151,8 @@ type Engine struct {
allNoMethod []HandlerFunc
noRoute []HandlerFunc
noMethod []HandlerFunc
pool sync.Pool
}
type injection struct {
@ -182,6 +184,9 @@ func NewServer(conf *ServerConfig) *Engine {
if err := engine.SetConfig(conf); err != nil {
panic(err)
}
engine.pool.New = func() interface{} {
return engine.newContext()
}
engine.RouterGroup.engine = engine
// NOTE add prometheus monitor location
engine.addRoute("GET", "/metrics", monitor())
@ -477,20 +482,18 @@ func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {
// ServeHTTP conforms to the http.Handler interface.
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := &Context{
Context: nil,
engine: engine,
index: -1,
handlers: nil,
Keys: nil,
method: "",
Error: nil,
}
c := engine.pool.Get().(*Context)
c.Request = req
c.Writer = w
c.reset()
engine.handleContext(c)
engine.pool.Put(c)
}
//newContext for sync.pool
func (engine *Engine) newContext() *Context {
return &Context{engine: engine}
}
// NoRoute adds handlers for NoRoute. It return a 404 code by default.