瀏覽代碼

[update] add plugin config (#1180)

Signed-off-by: lihaowei <haoweili35@gmail.com>
Howie 3 年之前
父節點
當前提交
cd1f8da13f
共有 4 個文件被更改,包括 58 次插入5 次删除
  1. 3 1
      rest/engine.go
  2. 8 4
      rest/internal/starter.go
  3. 10 0
      rest/server.go
  4. 37 0
      rest/server_test.go

+ 3 - 1
rest/engine.go

@@ -1,6 +1,7 @@
 package rest
 package rest
 
 
 import (
 import (
+	"crypto/tls"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
@@ -30,6 +31,7 @@ type engine struct {
 	middlewares          []Middleware
 	middlewares          []Middleware
 	shedder              load.Shedder
 	shedder              load.Shedder
 	priorityShedder      load.Shedder
 	priorityShedder      load.Shedder
+	tlsConfig            *tls.Config
 }
 }
 
 
 func newEngine(c RestConf) *engine {
 func newEngine(c RestConf) *engine {
@@ -70,7 +72,7 @@ func (s *engine) StartWithRouter(router httpx.Router) error {
 		return internal.StartHttp(s.conf.Host, s.conf.Port, router)
 		return internal.StartHttp(s.conf.Host, s.conf.Port, router)
 	}
 	}
 
 
-	return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, router)
+	return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router)
 }
 }
 
 
 func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
 func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,

+ 8 - 4
rest/internal/starter.go

@@ -2,6 +2,7 @@ package internal
 
 
 import (
 import (
 	"context"
 	"context"
+	"crypto/tls"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 
 
@@ -10,24 +11,27 @@ import (
 
 
 // StartHttp starts a http server.
 // StartHttp starts a http server.
 func StartHttp(host string, port int, handler http.Handler) error {
 func StartHttp(host string, port int, handler http.Handler) error {
-	return start(host, port, handler, func(srv *http.Server) error {
+	return start(host, port, handler, nil, func(srv *http.Server) error {
 		return srv.ListenAndServe()
 		return srv.ListenAndServe()
 	})
 	})
 }
 }
 
 
 // StartHttps starts a https server.
 // StartHttps starts a https server.
-func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
-	return start(host, port, handler, func(srv *http.Server) error {
+func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error {
+	return start(host, port, handler, tlsConfig, func(srv *http.Server) error {
 		// certFile and keyFile are set in buildHttpsServer
 		// certFile and keyFile are set in buildHttpsServer
 		return srv.ListenAndServeTLS(certFile, keyFile)
 		return srv.ListenAndServeTLS(certFile, keyFile)
 	})
 	})
 }
 }
 
 
-func start(host string, port int, handler http.Handler, run func(srv *http.Server) error) (err error) {
+func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) {
 	server := &http.Server{
 	server := &http.Server{
 		Addr:    fmt.Sprintf("%s:%d", host, port),
 		Addr:    fmt.Sprintf("%s:%d", host, port),
 		Handler: handler,
 		Handler: handler,
 	}
 	}
+	if tlsConfig != nil {
+		server.TLSConfig = tlsConfig
+	}
 	waitForCalled := proc.AddWrapUpListener(func() {
 	waitForCalled := proc.AddWrapUpListener(func() {
 		server.Shutdown(context.Background())
 		server.Shutdown(context.Background())
 	})
 	})

+ 10 - 0
rest/server.go

@@ -1,6 +1,7 @@
 package rest
 package rest
 
 
 import (
 import (
+	"crypto/tls"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
 
 
@@ -193,6 +194,15 @@ func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
 	}
 	}
 }
 }
 
 
+// WithTLSConfig returns a RunOption that with given tls config.
+func WithTLSConfig(cipherSuites []uint16) RunOption {
+	return func(engine *Server) {
+		engine.ngin.tlsConfig = &tls.Config{
+			CipherSuites: cipherSuites,
+		}
+	}
+}
+
 // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
 // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
 func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
 func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
 	return func(engine *Server) {
 	return func(engine *Server) {

+ 37 - 0
rest/server_test.go

@@ -1,6 +1,7 @@
 package rest
 package rest
 
 
 import (
 import (
+	"crypto/tls"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
@@ -217,3 +218,39 @@ func TestWithPriority(t *testing.T) {
 	WithPriority()(&fr)
 	WithPriority()(&fr)
 	assert.True(t, fr.priority)
 	assert.True(t, fr.priority)
 }
 }
+
+func TestWithTLSConfig(t *testing.T) {
+	const configYaml = `
+Name: foo
+Port: 54321
+`
+	var cnf RestConf
+	assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
+
+	testConfig := []uint16{
+		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+	}
+
+	testCases := []struct {
+		c    RestConf
+		opts []RunOption
+		res  *tls.Config
+	}{
+		{
+			c:    cnf,
+			opts: []RunOption{WithTLSConfig(testConfig)},
+			res:  &tls.Config{CipherSuites: testConfig},
+		},
+		{
+			c:    cnf,
+			opts: []RunOption{WithUnsignedCallback(nil)},
+			res:  nil,
+		},
+	}
+
+	for _, testCase := range testCases {
+		srv, err := NewServer(testCase.c, testCase.opts...)
+		assert.Nil(t, err)
+		assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
+	}
+}