diff --git a/auth/options.go b/auth/options.go index 3cd1bc69..cb395d04 100644 --- a/auth/options.go +++ b/auth/options.go @@ -79,6 +79,13 @@ func Credentials(id, secret string) Option { } } +// ClientToken sets the auth token to use when making requests +func ClientToken(token *Token) Option { + return func(o *Options) { + o.Token = token + } +} + // Provider set the auth provider func Provider(p provider.Provider) Option { return func(o *Options) { diff --git a/auth/service/service.go b/auth/service/service.go index 67c4d4ca..c0f7e7dd 100644 --- a/auth/service/service.go +++ b/auth/service/service.go @@ -70,35 +70,6 @@ func (s *svc) Init(opts ...auth.Option) { s.loadRules() } }() - - // we have client credentials and must load a new token - // periodically - if len(s.options.ID) > 0 || len(s.options.Secret) > 0 { - // get a token immediately - s.refreshToken() - - go func() { - tokenTimer := time.NewTicker(time.Minute) - - for { - <-tokenTimer.C - - // Do not get a new token if the current one has more than three - // minutes remaining. We do 3 minutes to allow multiple retires in - // the case one request fails - t := s.Options().Token - if t != nil && t.Expiry.Unix() > time.Now().Add(time.Minute*3).Unix() { - continue - } - - // jitter for up to 5 seconds, this stops - // all the services calling the auth service - // at the exact same time - time.Sleep(jitter.Do(time.Second * 5)) - s.refreshToken() - } - }() - } } func (s *svc) Options() auth.Options { @@ -313,33 +284,6 @@ func (s *svc) loadRules() { s.rules = rsp.Rules } -// refreshToken generates a new token for the service to use when making calls -func (s *svc) refreshToken() { - req := &pb.TokenRequest{ - TokenExpiry: int64((time.Minute * 15).Seconds()), - } - - if s.Options().Token == nil { - // we do not have a token, use the credentials to get one - req.Id = s.Options().ID - req.Secret = s.Options().Secret - } else { - // we have a token, refresh it - req.RefreshToken = s.Options().Token.RefreshToken - } - - rsp, err := s.auth.Token(context.TODO(), req) - s.Lock() - defer s.Unlock() - - if err != nil { - log.Errorf("Error generating token: %v", err) - return - } - - s.options.Token = serializeToken(rsp.Token) -} - func serializeToken(t *pb.Token) *auth.Token { return &auth.Token{ AccessToken: t.AccessToken, diff --git a/client/grpc/grpc.go b/client/grpc/grpc.go index 9445eabe..e25f3dd5 100644 --- a/client/grpc/grpc.go +++ b/client/grpc/grpc.go @@ -10,7 +10,6 @@ import ( "sync/atomic" "time" - "github.com/micro/go-micro/v2/auth" "github.com/micro/go-micro/v2/broker" "github.com/micro/go-micro/v2/client" "github.com/micro/go-micro/v2/client/selector" @@ -18,7 +17,6 @@ import ( "github.com/micro/go-micro/v2/errors" "github.com/micro/go-micro/v2/metadata" "github.com/micro/go-micro/v2/registry" - "github.com/micro/go-micro/v2/util/config" pnet "github.com/micro/go-micro/v2/util/net" "google.golang.org/grpc" @@ -117,13 +115,6 @@ func (g *grpcClient) call(ctx context.Context, node *registry.Node, req client.R // set the content type for the request header["x-content-type"] = req.ContentType() - // set the authorization header - if opts.ServiceToken || len(header["authorization"]) == 0 { - if h := g.authorizationHeader(); len(h) > 0 { - header["authorization"] = h - } - } - md := gmetadata.New(header) ctx = gmetadata.NewOutgoingContext(ctx, md) @@ -202,13 +193,6 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client // set the content type for the request header["x-content-type"] = req.ContentType() - // set the authorization header - if opts.ServiceToken || len(header["authorization"]) == 0 { - if h := g.authorizationHeader(); len(h) > 0 { - header["authorization"] = h - } - } - md := gmetadata.New(header) ctx = gmetadata.NewOutgoingContext(ctx, md) @@ -295,26 +279,6 @@ func (g *grpcClient) stream(ctx context.Context, node *registry.Node, req client }, nil } -func (g *grpcClient) authorizationHeader() string { - // if the caller specifies using service token or no token - // was passed with the request, set the service token - var srvToken string - if g.opts.Auth != nil && g.opts.Auth.Options().Token != nil { - srvToken = g.opts.Auth.Options().Token.AccessToken - } - if len(srvToken) > 0 { - return auth.BearerScheme + srvToken - } - - // fall back to using the authorization token set in config, - // this enables the CLI to provide a token - if token, err := config.Get("micro", "auth", "token"); err == nil && len(token) > 0 { - return auth.BearerScheme + token - } - - return "" -} - func (g *grpcClient) poolMaxStreams() int { if g.opts.Context == nil { return DefaultPoolMaxStreams diff --git a/service.go b/service.go index a49c696f..f7941e30 100644 --- a/service.go +++ b/service.go @@ -39,8 +39,9 @@ func newService(opts ...Option) Service { authFn := func() auth.Auth { return options.Server.Options().Auth } // wrap client to inject From-Service header on any calls - options.Client = wrapper.FromService(serviceName, options.Client, authFn) + options.Client = wrapper.FromService(serviceName, options.Client) options.Client = wrapper.TraceCall(serviceName, trace.DefaultTracer, options.Client) + options.Client = wrapper.AuthClient(serviceName, options.Server.Options().Id, authFn, options.Client) // wrap the server to provide handler stats options.Server.Init( diff --git a/util/wrapper/wrapper.go b/util/wrapper/wrapper.go index bfd1ce9f..ee569ef2 100644 --- a/util/wrapper/wrapper.go +++ b/util/wrapper/wrapper.go @@ -2,7 +2,9 @@ package wrapper import ( "context" + "fmt" "strings" + "time" "github.com/micro/go-micro/v2/auth" "github.com/micro/go-micro/v2/client" @@ -11,68 +13,44 @@ import ( "github.com/micro/go-micro/v2/errors" "github.com/micro/go-micro/v2/metadata" "github.com/micro/go-micro/v2/server" + "github.com/micro/go-micro/v2/util/config" ) -type clientWrapper struct { +type fromServiceWrapper struct { client.Client - // Auth interface - auth func() auth.Auth // headers to inject headers metadata.Metadata } -type traceWrapper struct { - client.Client - - name string - trace trace.Tracer -} - var ( HeaderPrefix = "Micro-" ) -func (c *clientWrapper) setHeaders(ctx context.Context) context.Context { +func (f *fromServiceWrapper) setHeaders(ctx context.Context) context.Context { // don't overwrite keys - return metadata.MergeContext(ctx, c.headers, false) + return metadata.MergeContext(ctx, f.headers, false) } -func (c *clientWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - ctx = c.setHeaders(ctx) - return c.Client.Call(ctx, req, rsp, opts...) +func (f *fromServiceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { + ctx = f.setHeaders(ctx) + return f.Client.Call(ctx, req, rsp, opts...) } -func (c *clientWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { - ctx = c.setHeaders(ctx) - return c.Client.Stream(ctx, req, opts...) +func (f *fromServiceWrapper) Stream(ctx context.Context, req client.Request, opts ...client.CallOption) (client.Stream, error) { + ctx = f.setHeaders(ctx) + return f.Client.Stream(ctx, req, opts...) } -func (c *clientWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { - ctx = c.setHeaders(ctx) - return c.Client.Publish(ctx, p, opts...) -} - -func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { - newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) - - s.Type = trace.SpanTypeRequestOutbound - err := c.Client.Call(newCtx, req, rsp, opts...) - if err != nil { - s.Metadata["error"] = err.Error() - } - - // finish the trace - c.trace.Finish(s) - - return err +func (f *fromServiceWrapper) Publish(ctx context.Context, p client.Message, opts ...client.PublishOption) error { + ctx = f.setHeaders(ctx) + return f.Client.Publish(ctx, p, opts...) } // FromService wraps a client to inject service and auth metadata -func FromService(name string, c client.Client, fn func() auth.Auth) client.Client { - return &clientWrapper{ +func FromService(name string, c client.Client) client.Client { + return &fromServiceWrapper{ c, - fn, metadata.Metadata{ HeaderPrefix + "From-Service": name, }, @@ -95,6 +73,28 @@ func HandlerStats(stats stats.Stats) server.HandlerWrapper { } } +type traceWrapper struct { + client.Client + + name string + trace trace.Tracer +} + +func (c *traceWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { + newCtx, s := c.trace.Start(ctx, req.Service()+"."+req.Endpoint()) + + s.Type = trace.SpanTypeRequestOutbound + err := c.Client.Call(newCtx, req, rsp, opts...) + if err != nil { + s.Metadata["error"] = err.Error() + } + + // finish the trace + c.trace.Finish(s) + + return err +} + // TraceCall is a call tracing wrapper func TraceCall(name string, t trace.Tracer, c client.Client) client.Client { return &traceWrapper{ @@ -132,6 +132,104 @@ func TraceHandler(t trace.Tracer) server.HandlerWrapper { } } +type authWrapper struct { + client.Client + name string + id string + auth func() auth.Auth +} + +func (a *authWrapper) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error { + // parse the options + var options client.CallOptions + for _, o := range opts { + o(&options) + } + + // check to see if the authorization header has already been set. + // We dont't override the header unless the ServiceToken option has + // been specified or the header wasn't provided + if _, ok := metadata.Get(ctx, "Authorization"); ok && !options.ServiceToken { + return a.Client.Call(ctx, req, rsp, opts...) + } + + // if auth is nil we won't be able to get an access token, so we execute + // the request without one. + aa := a.auth() + if a == nil { + return a.Client.Call(ctx, req, rsp, opts...) + } + + // performs the call with the authorization token provided + callWithToken := func(token string) error { + ctx := metadata.Set(ctx, "Authorization", auth.BearerScheme+token) + return a.Client.Call(ctx, req, rsp, opts...) + } + + // check to see if we have a valid access token + aaOpts := aa.Options() + if aaOpts.Token != nil && aaOpts.Token.Expiry.Unix() > time.Now().Unix() { + return callWithToken(aaOpts.Token.AccessToken) + } + + // if we have a refresh token we can use this to generate another access token + if aaOpts.Token != nil { + tok, err := aa.Token(auth.WithToken(aaOpts.Token.RefreshToken)) + if err != nil { + return err + } + aa.Init(auth.ClientToken(tok)) + return callWithToken(tok.AccessToken) + } + + // if we have credentials we can generate a new token for the account + if len(aaOpts.ID) > 0 && len(aaOpts.Secret) > 0 { + tok, err := aa.Token(auth.WithCredentials(aaOpts.ID, aaOpts.Secret)) + if err != nil { + return err + } + aa.Init(auth.ClientToken(tok)) + return callWithToken(tok.AccessToken) + } + + // check to see if a token was provided in config, this is normally used for + // setting the token when calling via the cli + if token, err := config.Get("micro", "auth", "token"); err == nil && len(token) > 0 { + return callWithToken(token) + } + + // determine the type of service from the name. we do this so we can allocate + // different roles depending on the type of services. e.g. we don't want web + // services talking directly to the runtime. TODO: find a better way to determine + // the type of service + serviceType := "service" + if strings.Contains(a.name, "api") { + serviceType = "api" + } else if strings.Contains(a.name, "web") { + serviceType = "web" + } + + // generate a new auth account for the service + name := fmt.Sprintf("%v-%v", a.name, a.id) + acc, err := aa.Generate(name, auth.WithNamespace(aaOpts.Namespace), auth.WithRoles(serviceType)) + if err != nil { + return err + } + token, err := aa.Token(auth.WithCredentials(acc.ID, acc.Secret)) + if err != nil { + return err + } + aa.Init(auth.ClientToken(token)) + + // use the token to execute the request + return callWithToken(token.AccessToken) +} + +// AuthClient wraps requests with the auth header +func AuthClient(name string, id string, auth func() auth.Auth, c client.Client) client.Client { + return &authWrapper{c, name, id, auth} +} + // AuthHandler wraps a server handler to perform auth func AuthHandler(fn func() auth.Auth) server.HandlerWrapper { return func(h server.HandlerFunc) server.HandlerFunc { diff --git a/util/wrapper/wrapper_test.go b/util/wrapper/wrapper_test.go index 7fb99bf3..fa03af21 100644 --- a/util/wrapper/wrapper_test.go +++ b/util/wrapper/wrapper_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - "github.com/micro/go-micro/v2/auth" "github.com/micro/go-micro/v2/metadata" ) @@ -33,8 +32,7 @@ func TestWrapper(t *testing.T) { } for _, d := range testData { - c := &clientWrapper{ - auth: func() auth.Auth { return nil }, + c := &fromServiceWrapper{ headers: d.headers, }