mirror of
https://github.com/go-kratos/kratos.git
synced 2025-01-24 03:46:37 +02:00
【#562】fix context Keys map concurrent issue (#561)
* 1、Fix context Keys map concurrent issue 2、add RemoteIP func to get client ip * context mutex default to zero-value Co-authored-by: John Sun <sunqiang@styd.cn>
This commit is contained in:
parent
d99f595ea1
commit
d29dfdfd67
@ -5,8 +5,12 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
"github.com/go-kratos/kratos/pkg/net/metadata"
|
||||
|
||||
"github.com/go-kratos/kratos/pkg/ecode"
|
||||
"github.com/go-kratos/kratos/pkg/net/http/blademaster/binding"
|
||||
"github.com/go-kratos/kratos/pkg/net/http/blademaster/render"
|
||||
@ -40,6 +44,8 @@ type Context struct {
|
||||
|
||||
// Keys is a key/value pair exclusively for the context of each request.
|
||||
Keys map[string]interface{}
|
||||
// This mutex protect Keys map
|
||||
keysMutex sync.RWMutex
|
||||
|
||||
Error error
|
||||
|
||||
@ -51,6 +57,20 @@ type Context struct {
|
||||
Params Params
|
||||
}
|
||||
|
||||
/************************************/
|
||||
/********** CONTEXT CREATION ********/
|
||||
/************************************/
|
||||
func (c *Context) reset() {
|
||||
c.Context = nil
|
||||
c.index = -1
|
||||
c.handlers = nil
|
||||
c.Keys = nil
|
||||
c.Error = nil
|
||||
c.method = ""
|
||||
c.RoutePath = ""
|
||||
c.Params = c.Params[0:0]
|
||||
}
|
||||
|
||||
/************************************/
|
||||
/*********** FLOW CONTROL ***********/
|
||||
/************************************/
|
||||
@ -93,16 +113,76 @@ func (c *Context) IsAborted() bool {
|
||||
// Set is used to store a new key/value pair exclusively for this context.
|
||||
// It also lazy initializes c.Keys if it was not used previously.
|
||||
func (c *Context) Set(key string, value interface{}) {
|
||||
c.keysMutex.Lock()
|
||||
if c.Keys == nil {
|
||||
c.Keys = make(map[string]interface{})
|
||||
}
|
||||
c.Keys[key] = value
|
||||
c.keysMutex.Unlock()
|
||||
}
|
||||
|
||||
// Get returns the value for the given key, ie: (value, true).
|
||||
// If the value does not exists it returns (nil, false)
|
||||
func (c *Context) Get(key string) (value interface{}, exists bool) {
|
||||
c.keysMutex.RLock()
|
||||
value, exists = c.Keys[key]
|
||||
c.keysMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// GetString returns the value associated with the key as a string.
|
||||
func (c *Context) GetString(key string) (s string) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
s, _ = val.(string)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetBool returns the value associated with the key as a boolean.
|
||||
func (c *Context) GetBool(key string) (b bool) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
b, _ = val.(bool)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetInt returns the value associated with the key as an integer.
|
||||
func (c *Context) GetInt(key string) (i int) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
i, _ = val.(int)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetUint returns the value associated with the key as an unsigned integer.
|
||||
func (c *Context) GetUint(key string) (ui uint) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
ui, _ = val.(uint)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetInt64 returns the value associated with the key as an integer.
|
||||
func (c *Context) GetInt64(key string) (i64 int64) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
i64, _ = val.(int64)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetUint64 returns the value associated with the key as an unsigned integer.
|
||||
func (c *Context) GetUint64(key string) (ui64 uint64) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
ui64, _ = val.(uint64)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetFloat64 returns the value associated with the key as a float64.
|
||||
func (c *Context) GetFloat64(key string) (f64 float64) {
|
||||
if val, ok := c.Get(key); ok && val != nil {
|
||||
f64, _ = val.(float64)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -307,3 +387,22 @@ func writeStatusCode(w http.ResponseWriter, ecode int) {
|
||||
header := w.Header()
|
||||
header.Set("kratos-status-code", strconv.FormatInt(int64(ecode), 10))
|
||||
}
|
||||
|
||||
// RemoteIP implements a best effort algorithm to return the real client IP, it parses
|
||||
// X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy.
|
||||
// Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP.
|
||||
// Notice: metadata.RemoteIP take precedence over X-Forwarded-For and X-Real-Ip
|
||||
func (c *Context) RemoteIP() (remoteIP string) {
|
||||
remoteIP = metadata.String(c, metadata.RemoteIP)
|
||||
if remoteIP != "" {
|
||||
return
|
||||
}
|
||||
|
||||
remoteIP = c.Request.Header.Get("X-Forwarded-For")
|
||||
remoteIP = strings.TrimSpace(strings.Split(remoteIP, ",")[0])
|
||||
if remoteIP == "" {
|
||||
remoteIP = strings.TrimSpace(c.Request.Header.Get("X-Real-Ip"))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -15,7 +15,6 @@ func Logger() HandlerFunc {
|
||||
const noUser = "no_user"
|
||||
return func(c *Context) {
|
||||
now := time.Now()
|
||||
ip := metadata.String(c, metadata.RemoteIP)
|
||||
req := c.Request
|
||||
path := req.URL.Path
|
||||
params := req.Form
|
||||
@ -55,7 +54,7 @@ func Logger() HandlerFunc {
|
||||
}
|
||||
lf(c,
|
||||
log.KVString("method", req.Method),
|
||||
log.KVString("ip", ip),
|
||||
log.KVString("ip", c.RemoteIP()),
|
||||
log.KVString("user", caller),
|
||||
log.KVString("path", path),
|
||||
log.KVString("params", params.Encode()),
|
||||
|
@ -151,6 +151,8 @@ type Engine struct {
|
||||
allNoMethod []HandlerFunc
|
||||
noRoute []HandlerFunc
|
||||
noMethod []HandlerFunc
|
||||
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
type injection struct {
|
||||
@ -182,6 +184,9 @@ func NewServer(conf *ServerConfig) *Engine {
|
||||
if err := engine.SetConfig(conf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
engine.pool.New = func() interface{} {
|
||||
return engine.newContext()
|
||||
}
|
||||
engine.RouterGroup.engine = engine
|
||||
// NOTE add prometheus monitor location
|
||||
engine.addRoute("GET", "/metrics", monitor())
|
||||
@ -477,20 +482,18 @@ func (engine *Engine) Inject(pattern string, handlers ...HandlerFunc) {
|
||||
|
||||
// ServeHTTP conforms to the http.Handler interface.
|
||||
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
c := &Context{
|
||||
Context: nil,
|
||||
engine: engine,
|
||||
index: -1,
|
||||
handlers: nil,
|
||||
Keys: nil,
|
||||
method: "",
|
||||
Error: nil,
|
||||
}
|
||||
|
||||
c := engine.pool.Get().(*Context)
|
||||
c.Request = req
|
||||
c.Writer = w
|
||||
c.reset()
|
||||
|
||||
engine.handleContext(c)
|
||||
engine.pool.Put(c)
|
||||
}
|
||||
|
||||
//newContext for sync.pool
|
||||
func (engine *Engine) newContext() *Context {
|
||||
return &Context{engine: engine}
|
||||
}
|
||||
|
||||
// NoRoute adds handlers for NoRoute. It return a 404 code by default.
|
||||
|
Loading…
x
Reference in New Issue
Block a user