1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-24 03:46:37 +02:00
2019-04-16 16:50:34 +08:00

154 lines
3.5 KiB
Go

package auth
import (
"github.com/bilibili/kratos/pkg/ecode"
bm "github.com/bilibili/kratos/pkg/net/http/blademaster"
"github.com/bilibili/kratos/pkg/net/metadata"
)
// Config holds the tunable options of the authorization middleware.
type Config struct {
	// DisableCSRF turns off the csrf form-value check that is otherwise
	// applied to cookie-authenticated POST requests.
	DisableCSRF bool
}
// Auth is the authorization middleware. Its exported methods are
// blademaster handlers that resolve the caller's user id (mid) and enforce
// either a login-required or a guest policy.
type Auth struct {
	conf *Config // middleware options; New substitutes a default when nil
}
// authFunc resolves the user id (mid) of the request held by ctx, or returns
// a non-nil error when the request cannot be authenticated.
type authFunc func(ctx *bm.Context) (mid int64, err error)
// _defaultConf is used by New when the caller passes a nil config. Its zero
// value leaves DisableCSRF false, i.e. CSRF checking stays enabled.
var _defaultConf = &Config{}
// New builds an authorization middleware from conf. When conf is nil the
// shared package default (_defaultConf) is used instead.
func New(conf *Config) *Auth {
	if conf != nil {
		return &Auth{conf: conf}
	}
	return &Auth{conf: _defaultConf}
}
// User marks a route as requiring a logged-in user.
// Requests that carry an `access_token` form value are authenticated with the
// mobile (token) policy; all other requests use the web (cookie) policy.
func (a *Auth) User(ctx *bm.Context) {
	if token := ctx.Request.Form.Get("access_token"); token != "" {
		a.UserMobile(ctx)
		return
	}
	a.UserWeb(ctx)
}
// UserWeb marks a route as requiring a logged-in web user, authenticated by
// the SESSION cookie. An authentication error aborts the request.
func (a *Auth) UserWeb(ctx *bm.Context) {
	var byCookie authFunc = a.authCookie
	a.midAuth(ctx, byCookie)
}
// UserMobile marks a route as requiring a logged-in mobile user,
// authenticated by the `access_token` form value. An authentication error
// aborts the request.
func (a *Auth) UserMobile(ctx *bm.Context) {
	var byToken authFunc = a.authToken
	a.midAuth(ctx, byToken)
}
// Guest marks a route with the guest policy (login optional).
// Requests that carry an `access_token` form value go through the mobile
// (token) guest policy; all other requests use the web (cookie) guest policy.
func (a *Auth) Guest(ctx *bm.Context) {
	switch ctx.Request.Form.Get("access_token") {
	case "":
		a.GuestWeb(ctx)
	default:
		a.GuestMobile(ctx)
	}
}
// GuestWeb marks a route with the web guest policy, trying to identify the
// caller through the SESSION cookie.
func (a *Auth) GuestWeb(ctx *bm.Context) {
	var byCookie authFunc = a.authCookie
	a.guestAuth(ctx, byCookie)
}
// GuestMobile marks a route with the mobile guest policy, trying to identify
// the caller through the `access_token` form value.
func (a *Auth) GuestMobile(ctx *bm.Context) {
	var byToken authFunc = a.authToken
	a.guestAuth(ctx, byToken)
}
// authToken resolves the user id from the `access_token` form value sent by
// mobile clients. A request without a token yields ecode.Unauthorized.
func (a *Auth) authToken(ctx *bm.Context) (int64, error) {
	token := ctx.Request.Form.Get("access_token")
	if len(token) == 0 {
		return 0, ecode.Unauthorized
	}
	// NOTE: call the login/authentication service with the token to obtain
	// the corresponding user id.
	// TODO: get mid from some code
	var mid int64
	return mid, nil
}
// authCookie resolves the user id from the SESSION cookie sent by web clients.
//
// It returns ecode.Unauthorized when the cookie is missing or empty. It returns
// ecode.AccessDenied when CSRF checking is enabled and a POST request does not
// carry the expected `csrf` form value. The two codes are kept distinct so that
// guestAuth can let anonymous visitors through while still rejecting CSRF
// failures.
func (a *Auth) authCookie(ctx *bm.Context) (int64, error) {
	req := ctx.Request
	session, err := req.Cookie("SESSION")
	if err != nil || session.Value == "" {
		return 0, ecode.Unauthorized
	}
	// NOTE: call the login/authentication service with session.Value to
	// obtain the corresponding user id.
	var mid int64
	// TODO: get mid from some code

	// check csrf
	if a.conf != nil && !a.conf.DisableCSRF && req.Method == "POST" {
		// NOTE: when CSRF checking is enabled, fetch the csrf token bound to
		// this user from the CSRF service.
		var csrf string // TODO: get csrf from some code
		clientCsrf := req.FormValue("csrf")
		// An empty client token is rejected outright so that an unset
		// server-side token can never be "matched" by omitting the field.
		if clientCsrf == "" || clientCsrf != csrf {
			return 0, ecode.AccessDenied
		}
	}
	return mid, nil
}

// midAuth enforces a login-required policy. It runs auth; on any error it
// writes that error as the JSON response and aborts the handler chain.
// Otherwise it stores the resolved mid on the context via setMid.
func (a *Auth) midAuth(ctx *bm.Context, auth authFunc) {
	mid, err := auth(ctx)
	if err != nil {
		ctx.JSON(nil, err)
		ctx.Abort()
		return
	}
	setMid(ctx, mid)
}

// guestAuth enforces a guest (login-optional) policy.
// When auth succeeds with a positive mid, the mid is stored on the context.
// A CSRF failure (ecode.AccessDenied) is written as the JSON response and
// aborts the chain. Every other outcome, including missing credentials
// (ecode.Unauthorized), lets the request continue anonymously.
func (a *Auth) guestAuth(ctx *bm.Context, auth authFunc) {
	mid, err := auth(ctx)
	if err == nil {
		// no error happened; only a positive mid identifies a user
		if mid > 0 {
			setMid(ctx, mid)
		}
		return
	}
	if ec := ecode.Cause(err); ecode.Equal(ec, ecode.AccessDenied) {
		ctx.JSON(nil, ec)
		ctx.Abort()
	}
}
// setMid records the authenticated user id under the metadata.Mid key. It
// always stores it in the bm.Context key/value store. If the context also
// carries request metadata, it writes the id into that metadata map as well.
// NOTE: This method is not thread safe (the metadata map is written without
// any locking).
func setMid(ctx *bm.Context, mid int64) {
	ctx.Set(metadata.Mid, mid)
	md, ok := metadata.FromContext(ctx)
	if !ok {
		return
	}
	md[metadata.Mid] = mid
}