1
0
mirror of https://github.com/rclone/rclone.git synced 2025-01-24 12:56:36 +02:00

lib/http: Simplify server.go to export an http server rather than an interface

This also makes the implementation public.
This commit is contained in:
Tom Mombourquette 2022-11-23 05:44:53 -04:00 committed by Nick Craig-Wood
parent 2a2fcf1012
commit ec7cc2b3c3
5 changed files with 33 additions and 37 deletions

View File

@ -103,7 +103,7 @@ control the stats printing.
type serveCmd struct {
f fs.Fs
vfs *vfs.VFS
server libhttp.Server
server *libhttp.Server
}
func run(ctx context.Context, f fs.Fs, opt Options) (*serveCmd, error) {

View File

@ -3,7 +3,7 @@ package serve
import (
"errors"
"html/template"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"net/url"
@ -94,7 +94,7 @@ func TestError(t *testing.T) {
Error("potato", w, "sausage", err)
resp := w.Result()
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "sausage.\n", string(body))
}
@ -108,7 +108,7 @@ func TestServe(t *testing.T) {
d.Serve(w, r)
resp := w.Result()
assert.Equal(t, http.StatusOK, resp.StatusCode)
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, `<!DOCTYPE html>
<html lang="en">
<head>

View File

@ -1,7 +1,7 @@
package serve
import (
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"testing"
@ -17,7 +17,7 @@ func TestObjectBadMethod(t *testing.T) {
Object(w, r, o)
resp := w.Result()
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "Method Not Allowed\n", string(body))
}
@ -30,7 +30,7 @@ func TestObjectHEAD(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "5", resp.Header.Get("Content-Length"))
assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges"))
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "", string(body))
}
@ -43,7 +43,7 @@ func TestObjectGET(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "5", resp.Header.Get("Content-Length"))
assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges"))
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "hello", string(body))
}
@ -58,7 +58,7 @@ func TestObjectRange(t *testing.T) {
assert.Equal(t, "3", resp.Header.Get("Content-Length"))
assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges"))
assert.Equal(t, "bytes 3-5/10", resp.Header.Get("Content-Range"))
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "345", string(body))
}
@ -71,6 +71,6 @@ func TestObjectBadRange(t *testing.T) {
resp := w.Result()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.Equal(t, "10", resp.Header.Get("Content-Length"))
body, _ := ioutil.ReadAll(resp.Body)
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "Bad Request\n", string(body))
}

View File

@ -121,16 +121,6 @@ func DefaultCfg() Config {
}
}
// Server interface of http server
type Server interface {
Router() chi.Router
Serve()
Shutdown() error
HTMLTemplate() *template.Template
URLs() []string
Wait()
}
type instance struct {
url string
listener net.Listener
@ -145,7 +135,8 @@ func (s instance) serve(wg *sync.WaitGroup) {
}
}
type server struct {
// Server contains info about the running http server
type Server struct {
wg sync.WaitGroup
mux chi.Router
tlsConfig *tls.Config
@ -157,25 +148,25 @@ type server struct {
}
// Option allows customizing the server
type Option func(*server)
type Option func(*Server)
// WithAuth option initializes the appropriate auth middleware
func WithAuth(cfg AuthConfig) Option {
return func(s *server) {
return func(s *Server) {
s.auth = cfg
}
}
// WithConfig option applies the Config to the server, overriding defaults
func WithConfig(cfg Config) Option {
return func(s *server) {
return func(s *Server) {
s.cfg = cfg
}
}
// WithTemplate option allows the parsing of a template
func WithTemplate(cfg TemplateConfig) Option {
return func(s *server) {
return func(s *Server) {
s.template = &cfg
}
}
@ -184,12 +175,17 @@ func WithTemplate(cfg TemplateConfig) Option {
// This function is provided if the default http server does not meet a services requirements and should not generally be used
// A http server can listen using multiple listeners. For example, a listener for port 80, and a listener for port 443.
// tlsListeners are ignored if opt.TLSKey is not provided
func NewServer(ctx context.Context, options ...Option) (*server, error) {
s := &server{
func NewServer(ctx context.Context, options ...Option) (*Server, error) {
s := &Server{
mux: chi.NewRouter(),
cfg: DefaultCfg(),
}
// Make sure default logger is logging where everything else is
// middleware.DefaultLogger = middleware.RequestLogger(&middleware.DefaultLogFormatter{Logger: log.Default(), NoColor: true})
// Log requests
// s.mux.Use(middleware.Logger)
for _, opt := range options {
opt(s)
}
@ -275,7 +271,7 @@ func NewServer(ctx context.Context, options ...Option) (*server, error) {
return s, nil
}
func (s *server) initAuth() {
func (s *Server) initAuth() {
if s.auth.CustomAuthFn != nil {
s.mux.Use(MiddlewareAuthCustom(s.auth.CustomAuthFn, s.auth.Realm))
return
@ -292,7 +288,7 @@ func (s *server) initAuth() {
}
}
func (s *server) initTemplate() error {
func (s *Server) initTemplate() error {
if s.template == nil {
return nil
}
@ -317,7 +313,7 @@ var (
ErrTLSParseCA = errors.New("unable to parse client certificate authority")
)
func (s *server) initTLS() error {
func (s *Server) initTLS() error {
if s.cfg.TLSCert == "" && s.cfg.TLSKey == "" && len(s.cfg.TLSCertBody) == 0 && len(s.cfg.TLSKeyBody) == 0 {
return nil
}
@ -383,7 +379,7 @@ func (s *server) initTLS() error {
}
// Serve starts the HTTP server on each listener
func (s *server) Serve() {
func (s *Server) Serve() {
s.wg.Add(len(s.instances))
for _, ii := range s.instances {
// TODO: decide how/when to log listening url
@ -393,17 +389,17 @@ func (s *server) Serve() {
}
// Wait blocks while the server is serving requests
func (s *server) Wait() {
func (s *Server) Wait() {
s.wg.Wait()
}
// Router returns the server base router
func (s *server) Router() chi.Router {
func (s *Server) Router() chi.Router {
return s.mux
}
// Shutdown gracefully shuts down the server
func (s *server) Shutdown() error {
func (s *Server) Shutdown() error {
ctx := context.Background()
for _, ii := range s.instances {
if err := ii.httpServer.Shutdown(ctx); err != nil {
@ -416,12 +412,12 @@ func (s *server) Shutdown() error {
}
// HTMLTemplate returns the parsed template, if WithTemplate option was passed.
func (s *server) HTMLTemplate() *template.Template {
func (s *Server) HTMLTemplate() *template.Template {
return s.htmlTemplate
}
// URLs returns all configured URLS
func (s *server) URLs() []string {
func (s *Server) URLs() []string {
var out []string
for _, ii := range s.instances {
if ii.listener.Addr().Network() == "unix" {

View File

@ -26,7 +26,7 @@ func testExpectRespBody(t *testing.T, resp *http.Response, expected []byte) {
require.Equal(t, expected, body)
}
func testGetServerURL(t *testing.T, s Server) string {
func testGetServerURL(t *testing.T, s *Server) string {
urls := s.URLs()
require.GreaterOrEqual(t, len(urls), 1, "server should return at least one url")
return urls[0]