From 7206d5f964ee89a1b09c3566bbd47674e847b5bf Mon Sep 17 00:00:00 2001 From: Ben Toogood Date: Tue, 7 Apr 2020 09:40:40 +0100 Subject: [PATCH] Add Namespace to CombinedAuthHandler --- api/server/auth/auth.go | 80 ++++++++++++++++++------------------ api/server/auth/auth_test.go | 34 +++++++++++++++ api/server/http/http.go | 2 +- api/server/options.go | 23 +++++++---- 4 files changed, 91 insertions(+), 48 deletions(-) create mode 100644 api/server/auth/auth_test.go diff --git a/api/server/auth/auth.go b/api/server/auth/auth.go index f0e38562..07d28068 100644 --- a/api/server/auth/auth.go +++ b/api/server/auth/auth.go @@ -15,38 +15,32 @@ import ( ) // CombinedAuthHandler wraps a server and authenticates requests -func CombinedAuthHandler(namespace string, r resolver.Resolver, h http.Handler) http.Handler { +func CombinedAuthHandler(prefix, namespace string, r resolver.Resolver, h http.Handler) http.Handler { if r == nil { r = path.NewResolver() } - if len(namespace) == 0 { - namespace = "go.micro" - } return authHandler{ - handler: h, - resolver: r, - auth: auth.DefaultAuth, - namespace: namespace, + handler: h, + resolver: r, + auth: auth.DefaultAuth, + servicePrefix: prefix, + namespace: namespace, } } type authHandler struct { - handler http.Handler - auth auth.Auth - resolver resolver.Resolver - namespace string + handler http.Handler + auth auth.Auth + resolver resolver.Resolver + namespace string + servicePrefix string } func (h authHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // Determine the namespace - namespace, err := namespaceFromRequest(req) - if err != nil { - logger.Error(err) - namespace = auth.DefaultNamespace - } - - // Set the namespace in the header + // Determine the namespace and set it in the header + namespace := h.namespaceFromRequest(req) + fmt.Printf("Namespace is %v\n", namespace) req.Header.Set(auth.NamespaceKey, namespace) // Extract the token from the request @@ -96,7 +90,7 @@ func (h authHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // construct the resource name, e.g. home => go.micro.web.home - resName := h.namespace + resName := h.servicePrefix if len(endpoint.Name) > 0 { resName = resName + "." + endpoint.Name } @@ -138,39 +132,47 @@ func (h authHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { http.Redirect(w, req, loginWithRedirect, http.StatusTemporaryRedirect) } -func namespaceFromRequest(req *http.Request) (string, error) { - // needed to tmp debug host in prod. will be removed. - logger.Infof("Host is '%v'; URL Host is '%v'; URL Hostname is '%v'", req.Host, req.URL.Host, req.URL.Hostname()) +func (h authHandler) namespaceFromRequest(req *http.Request) string { + // check to see what the provided namespace is, we only do + // domain mapping if the namespace is set to 'domain' + if h.namespace != "domain" { + return h.namespace + } // determine the host, e.g. dev.micro.mu:8080 - host := req.URL.Hostname() - if len(host) == 0 { - // fallback to req.Host - host, _, _ = net.SplitHostPort(req.Host) + var host string + if h, _, err := net.SplitHostPort(req.Host); err == nil { + host = h // host does contain a port + } else { + host = req.Host // host does not contain a port + } + + // check for the micro.mu domain + if strings.HasSuffix(host, "micro.mu") { + return auth.DefaultNamespace } // check for an ip address if net.ParseIP(host) != nil { - return auth.DefaultNamespace, nil + return auth.DefaultNamespace } // check for dev enviroment if host == "localhost" || host == "127.0.0.1" { - return auth.DefaultNamespace, nil + return auth.DefaultNamespace } // if host is not a subdomain, deturn default namespace comps := strings.Split(host, ".") - if len(comps) != 3 { - return auth.DefaultNamespace, nil + if len(comps) < 3 { + return auth.DefaultNamespace } - // check for the micro.mu domain - domain := fmt.Sprintf("%v.%v", comps[1], comps[2]) - if domain == "micro.mu" { - return auth.DefaultNamespace, nil + // return the reversed subdomain as the namespace + nComps := comps[0 : len(comps)-2] + for i := len(nComps)/2 - 1; i >= 0; i-- { + opp := len(nComps) - 1 - i + nComps[i], nComps[opp] = nComps[opp], nComps[i] } - - // return the subdomain as the host - return comps[0], nil + return strings.Join(nComps, ".") } diff --git a/api/server/auth/auth_test.go b/api/server/auth/auth_test.go new file mode 100644 index 00000000..cf454e28 --- /dev/null +++ b/api/server/auth/auth_test.go @@ -0,0 +1,34 @@ +package auth + +import ( + "net/http" + "testing" + + "github.com/micro/go-micro/v2/auth" +) + +func TestNamespaceFromRequest(t *testing.T) { + tt := []struct { + Host string + Namespace string + }{ + {Host: "micro.mu", Namespace: auth.DefaultNamespace}, + {Host: "web.micro.mu", Namespace: auth.DefaultNamespace}, + {Host: "api.micro.mu", Namespace: auth.DefaultNamespace}, + {Host: "myapp.com", Namespace: auth.DefaultNamespace}, + {Host: "staging.myapp.com", Namespace: "staging"}, + {Host: "staging.myapp.m3o.app", Namespace: "myapp.staging"}, + {Host: "127.0.0.1", Namespace: auth.DefaultNamespace}, + {Host: "localhost", Namespace: auth.DefaultNamespace}, + {Host: "81.151.101.146", Namespace: auth.DefaultNamespace}, + } + + for _, tc := range tt { + t.Run(tc.Host, func(t *testing.T) { + ns := namespaceFromRequest(&http.Request{Host: tc.Host}) + if ns != tc.Namespace { + t.Errorf("Expected namespace %v for host %v, actually got %v", tc.Namespace, tc.Host, ns) + } + }) + } +} diff --git a/api/server/http/http.go b/api/server/http/http.go index 2599d2db..02238aa9 100644 --- a/api/server/http/http.go +++ b/api/server/http/http.go @@ -53,7 +53,7 @@ func (s *httpServer) Init(opts ...server.Option) error { func (s *httpServer) Handle(path string, handler http.Handler) { h := handlers.CombinedLoggingHandler(os.Stdout, handler) - h = auth.CombinedAuthHandler(s.opts.Namespace, s.opts.Resolver, handler) + h = auth.CombinedAuthHandler(s.opts.ServiceNamespace, s.opts.Namespace, s.opts.Resolver, handler) if s.opts.EnableCORS { h = cors.CombinedCORSHandler(h) diff --git a/api/server/options.go b/api/server/options.go index 5d167ced..9d429436 100644 --- a/api/server/options.go +++ b/api/server/options.go @@ -10,14 +10,15 @@ import ( type Option func(o *Options) type Options struct { - EnableACME bool - EnableCORS bool - ACMEProvider acme.Provider - EnableTLS bool - ACMEHosts []string - TLSConfig *tls.Config - Namespace string - Resolver resolver.Resolver + EnableACME bool + EnableCORS bool + ACMEProvider acme.Provider + EnableTLS bool + ACMEHosts []string + TLSConfig *tls.Config + Resolver resolver.Resolver + Namespace string + ServiceNamespace string } func EnableCORS(b bool) Option { @@ -56,6 +57,12 @@ func TLSConfig(t *tls.Config) Option { } } +func ServiceNamespace(n string) Option { + return func(o *Options) { + o.ServiceNamespace = n + } +} + func Namespace(n string) Option { return func(o *Options) { o.Namespace = n