Bläddra i källkod

重构HTTP服务器和SSL服务器代码

将`version1`重命名为`httpdemov1`,并更新相关导入路径。修改了HTTP和HTTPS服务器的启动和停止逻辑,增加了对`go-proxyproto`的支持,并改进了证书监听功能。
SongZihuan 2 månader sedan
förälder
incheckning
baaa6f524e

+ 1 - 0
go.mod

@@ -31,6 +31,7 @@ require (
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
+	github.com/pires/go-proxyproto v0.8.0 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect

+ 2 - 0
go.sum

@@ -79,6 +79,8 @@ github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:Ff
 github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
 github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
+github.com/pires/go-proxyproto v0.8.0 h1:5unRmEAPbHXHuLjDg01CxJWf91cw3lKHc/0xzKpXEe0=
+github.com/pires/go-proxyproto v0.8.0/go.mod h1:iknsfgnH8EkjrMeMyvfKByp9TiBZCKZM0jx2xmKqnVY=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

+ 6 - 10
src/certssl/main.go

@@ -61,26 +61,22 @@ type NewCert struct {
 	Error             error
 }
 
-func WatchCertificate(dir string, email string, aliyunAccessKey string, aliyunAccessSecret string, domain string, oldCert *x509.Certificate, stopchan chan bool, newchan chan NewCert) error {
+func WatchCertificate(dir string, email string, aliyunAccessKey string, aliyunAccessSecret string, domain string, oldCert *x509.Certificate, stopchan chan bool, newCertChan chan NewCert) error {
+	ticker := time.Tick(1 * time.Hour)
+
 	for {
 		select {
 		case <-stopchan:
-			newchan <- NewCert{
-				PrivateKey:  nil,
-				Certificate: nil,
-				Error:       nil,
-			}
-			close(stopchan)
 			return nil
-		default:
+		case <-ticker:
 			privateKey, cert, cacert, err := watchCertificate(dir, email, aliyunAccessKey, aliyunAccessSecret, domain, oldCert)
 			if err != nil {
-				newchan <- NewCert{
+				newCertChan <- NewCert{
 					Error: fmt.Errorf("watch cert failed: %s", err.Error()),
 				}
 			} else if privateKey != nil && cert != nil && cacert != nil {
 				oldCert = cert
-				newchan <- NewCert{
+				newCertChan <- NewCert{
 					PrivateKey:        privateKey,
 					Certificate:       cert,
 					IssuerCertificate: cacert,

+ 10 - 0
src/cmd/httpdemo/httpdemov1/main.go

@@ -0,0 +1,10 @@
+package main
+
+import (
+	"github.com/SongZihuan/http-demo/src/mainfunc/httpdemo"
+	"os"
+)
+
+func main() {
+	os.Exit(httpdemo.MainV1())
+}

+ 0 - 10
src/cmd/version1/main.go

@@ -1,10 +0,0 @@
-package main
-
-import (
-	"github.com/SongZihuan/http-demo/src/mainfunc"
-	"os"
-)
-
-func main() {
-	os.Exit(mainfunc.MainV1())
-}

+ 45 - 2
src/httpserver/server.go

@@ -1,14 +1,19 @@
 package httpserver
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"github.com/SongZihuan/http-demo/src/engine"
 	"github.com/SongZihuan/http-demo/src/flagparser"
+	"github.com/pires/go-proxyproto"
+	"net"
 	"net/http"
+	"time"
 )
 
 var HttpServer *http.Server = nil
+var HttpListener net.Listener = nil
 var HttpAddress string
 
 var ErrStop = fmt.Errorf("http server error")
@@ -24,11 +29,49 @@ func InitHttpServer() error {
 	return nil
 }
 
-func RunServer() error {
+func RunServer() (err error) {
+	tcpListener, err := net.Listen("tcp", HttpServer.Addr)
+	if err != nil {
+		return err
+	}
+
+	proxyListener := &proxyproto.Listener{
+		Listener:          tcpListener,
+		ReadHeaderTimeout: 10 * time.Second,
+	}
+
+	HttpListener = proxyListener
+
+	defer func() {
+		_ = HttpListener.Close()
+
+		HttpServer = nil
+		HttpListener = nil
+	}()
+
 	fmt.Printf("http server start at %s\n", HttpAddress)
-	err := HttpServer.ListenAndServe()
+	err = HttpServer.Serve(HttpListener)
 	if err != nil && errors.Is(err, http.ErrServerClosed) {
 		return ErrStop
 	}
 	return fmt.Errorf("http server error: %s", err)
 }
+
+func StopServer() (err error) {
+	if HttpServer == nil {
+		return nil
+	}
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+
+	err = HttpServer.Shutdown(ctx)
+	if err != nil {
+		return err
+	}
+
+	HttpServer = nil
+	HttpListener = nil
+
+	return nil
+}

+ 131 - 54
src/httpsslserver/server.go

@@ -10,12 +10,15 @@ import (
 	"github.com/SongZihuan/http-demo/src/certssl"
 	"github.com/SongZihuan/http-demo/src/engine"
 	"github.com/SongZihuan/http-demo/src/flagparser"
+	"github.com/pires/go-proxyproto"
+	"net"
 	"net/http"
 	"sync"
 	"time"
 )
 
 var HttpSSLServer *http.Server = nil
+var HttpSSLListener net.Listener = nil
 var HttpSSLAddress string
 var HttpSSLDomain string
 var HttpSSLEmail string
@@ -62,6 +65,56 @@ func initHttpSSLServer() (err error) {
 		return fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
 	}
 
+	HttpSSLServer = &http.Server{
+		Addr:    HttpSSLAddress,
+		Handler: engine.Engine,
+	}
+
+	return nil
+}
+
+func RunServer() error {
+	watchStopChan := make(chan bool)
+	defer close(watchStopChan)
+
+	watchCertificate(watchStopChan)
+
+	err := runServer()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func StopServer() (err error) {
+	if HttpSSLServer == nil {
+		return nil
+	}
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+
+	err = HttpSSLServer.Shutdown(ctx)
+	if err != nil {
+		return err
+	}
+
+	HttpSSLServer = nil
+	HttpSSLListener = nil
+
+	return nil
+}
+
+func loadListener() (err error) {
+	if PrivateKey == nil || Certificate == nil || IssuerCertificate == nil {
+		return fmt.Errorf("init https server error: get key and cert error, return nil, unknown reason")
+	}
+
+	if Certificate.Raw == nil || len(Certificate.Raw) == 0 || IssuerCertificate.Raw == nil || len(IssuerCertificate.Raw) == 0 {
+		return fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
+	}
+
 	tlsConfig := &tls.Config{
 		Certificates: []tls.Certificate{{
 			Certificate: [][]byte{Certificate.Raw, IssuerCertificate.Raw}, // Raw包含 DER 编码的证书
@@ -71,81 +124,105 @@ func initHttpSSLServer() (err error) {
 		MinVersion: tls.VersionTLS12,
 	}
 
-	HttpSSLServer = &http.Server{
-		Addr:      HttpSSLAddress,
-		Handler:   engine.Engine,
-		TLSConfig: tlsConfig,
+	tcpListener, err := net.Listen("tcp", HttpSSLServer.Addr)
+	if err != nil {
+		return err
 	}
 
-	return nil
-}
+	proxyListener := &proxyproto.Listener{
+		Listener:          tcpListener,
+		ReadHeaderTimeout: 10 * time.Second,
+	}
 
-func RunServer() error {
-	stopchan := make(chan bool)
-	WatchCertificate(stopchan)
-	err := runServer()
-	stopchan <- true
-	return err
+	tlsListener := tls.NewListener(proxyListener, tlsConfig)
+	HttpSSLListener = tlsListener
+
+	return nil
 }
 
 func runServer() error {
-	fmt.Printf("https server start at %s\n", HttpSSLAddress)
-ListenCycle:
+	defer func() {
+		HttpSSLServer = nil
+		HttpSSLListener = nil
+	}()
+
 	for {
-		err := HttpSSLServer.ListenAndServeTLS("", "")
-		if err != nil && errors.Is(err, http.ErrServerClosed) {
-			if ReloadMutex.TryLock() {
-				ReloadMutex.Unlock()
-				return ErrStop
+		err := func() error {
+			err := loadListener()
+			if err != nil {
+				return err
+			}
+			defer func() {
+				_ = HttpSSLListener.Close()
+			}()
+
+			fmt.Printf("https server start at %s\n", HttpSSLAddress)
+			err = HttpSSLServer.Serve(HttpSSLListener)
+			if err != nil && errors.Is(err, http.ErrServerClosed) {
+				if ReloadMutex.TryLock() {
+					ReloadMutex.Unlock()
+					return ErrStop
+				}
+				ReloadMutex.Lock()
+				ReloadMutex.Unlock() // 等待证书更换完毕
+				return nil
+			} else if err != nil {
+				return fmt.Errorf("https server error: %s", err.Error())
 			}
-			ReloadMutex.Lock()
-			ReloadMutex.Unlock() // 等待证书更换完毕
-			continue ListenCycle
-		} else if err != nil {
-			return fmt.Errorf("https server error: %s", err.Error())
+
+			return nil
+		}()
+		if err != nil {
+			return err
 		}
 	}
 }
 
-func WatchCertificate(stopchan chan bool) {
-	newchan := make(chan certssl.NewCert)
+func watchCertificate(stopchan chan bool) {
+	newCertChan := make(chan certssl.NewCert)
 
 	go func() {
-		err := certssl.WatchCertificate(HttpSSLCertDir, HttpSSLEmail, HttpSSLAliyunAccessKey, HttpSSLAliyunAccessSecret, HttpSSLDomain, Certificate, stopchan, newchan)
+		err := certssl.WatchCertificate(HttpSSLCertDir, HttpSSLEmail, HttpSSLAliyunAccessKey, HttpSSLAliyunAccessSecret, HttpSSLDomain, Certificate, stopchan, newCertChan)
 		if err != nil {
 			fmt.Printf("watch https cert server error: %s", err.Error())
 		}
 	}()
 
 	go func() {
-		select {
-		case res := <-newchan:
-			if res.Certificate == nil && res.PrivateKey == nil && res.Error == nil {
-				close(newchan)
+		close(newCertChan)
+
+		for {
+			select {
+			case <-stopchan:
 				return
-			} else if res.Error != nil {
-				fmt.Printf("https cert reload server error: %s", res.Error.Error())
-			} else if res.PrivateKey != nil && res.Certificate != nil && res.IssuerCertificate != nil {
-				func() {
-					ReloadMutex.Lock()
-					defer ReloadMutex.Unlock()
-
-					ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-					defer cancel()
-
-					err := HttpSSLServer.Shutdown(ctx)
-					if err != nil {
-						fmt.Printf("https server reload shutdown error: %s", err.Error())
-					}
-
-					PrivateKey = res.PrivateKey
-					Certificate = res.Certificate
-					IssuerCertificate = res.IssuerCertificate
-					err = initHttpSSLServer()
-					if err != nil {
-						fmt.Printf("https server reload init error: %s", err.Error())
-					}
-				}()
+			case res := <-newCertChan:
+				if res.Certificate == nil && res.PrivateKey == nil && res.Error == nil {
+					close(newCertChan)
+					return
+				} else if res.Error != nil {
+					fmt.Printf("https cert reload server error: %s", res.Error.Error())
+				} else if res.PrivateKey != nil && res.Certificate != nil && res.IssuerCertificate != nil {
+					func() {
+						ReloadMutex.Lock()
+						defer ReloadMutex.Unlock()
+
+						ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+						defer cancel()
+
+						err := HttpSSLServer.Shutdown(ctx)
+						if err != nil {
+							fmt.Printf("https server reload shutdown error: %s", err.Error())
+						}
+
+						PrivateKey = res.PrivateKey
+						Certificate = res.Certificate
+						IssuerCertificate = res.IssuerCertificate
+						err = initHttpSSLServer()
+						if err != nil {
+							fmt.Printf("https server reload init error: %s", err.Error())
+						}
+					}()
+				}
 			}
 		}
 	}()

+ 17 - 1
src/mainfunc/version1.go → src/mainfunc/httpdemo/version1.go

@@ -1,4 +1,4 @@
-package mainfunc
+package httpdemo
 
 import (
 	"errors"
@@ -8,6 +8,7 @@ import (
 	"github.com/SongZihuan/http-demo/src/httpserver"
 	"github.com/SongZihuan/http-demo/src/httpsslserver"
 	"github.com/SongZihuan/http-demo/src/signalchan"
+	"sync"
 )
 
 func MainV1() (exitcode int) {
@@ -105,6 +106,21 @@ func MainV1() (exitcode int) {
 		}()
 	}
 
+	defer func() {
+		var wg sync.WaitGroup
+
+		go func() {
+			defer wg.Done()
+			_ = httpserver.StopServer()
+		}()
+
+		go func() {
+			defer wg.Done()
+			_ = httpsslserver.StopServer()
+		}()
+
+	}()
+
 	select {
 	case <-signalchan.SignalChan:
 		fmt.Printf("Server closed: safe\n")