Преглед изворни кода

更新配置字段命名并重构代码

更新了多个配置文件中的字段命名,使其更具可读性,并对部分代码进行了重构以提高可维护性。删除了旧版本的主程序文件,并添加了新的命令文件以支持新版本。同时,增加了对ProxyProtocol的支持。
SongZihuan пре 2 месеци
родитељ
комит
60bda723b0

+ 2 - 0
.gitignore

@@ -12,3 +12,5 @@ testdata
 remote-testdata
 
 pkg
+
+go-remote.sh

+ 1 - 0
go.mod

@@ -20,6 +20,7 @@ require (
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect
+	github.com/pires/go-proxyproto v0.8.0 // indirect
 	golang.org/x/crypto v0.31.0 // indirect
 	golang.org/x/mod v0.22.0 // indirect
 	golang.org/x/net v0.33.0 // indirect

+ 2 - 0
go.sum

@@ -51,6 +51,8 @@ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWb
 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
 github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b h1:FfH+VrHHk6Lxt9HdVS0PXzSXFyS2NbZKXv33FYPol0A=
 github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b/go.mod h1:AC62GU6hc0BrNm+9RK9VSiwa/EUe1bkIeFORAMcHvJU=
+github.com/pires/go-proxyproto v0.8.0 h1:5unRmEAPbHXHuLjDg01CxJWf91cw3lKHc/0xzKpXEe0=
+github.com/pires/go-proxyproto v0.8.0/go.mod h1:iknsfgnH8EkjrMeMyvfKByp9TiBZCKZM0jx2xmKqnVY=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=

+ 7 - 11
src/certssl/main.go

@@ -61,26 +61,22 @@ type NewCert struct {
 	Error             error
 }
 
-func WatchCertificate(dir string, email string, aliyunAccessKey string, aliyunAccessSecret string, domain string, oldCert *x509.Certificate, stopchan chan bool, newchan chan NewCert) error {
+func WatchCertificate(dir string, email string, aliyunAccessKey string, aliyunAccessSecret string, domain string, oldCert *x509.Certificate, stopChan chan bool, newCertChan chan NewCert) error {
+	ticker := time.Tick(1 * time.Hour)
+
 	for {
 		select {
-		case <-stopchan:
-			newchan <- NewCert{
-				PrivateKey:  nil,
-				Certificate: nil,
-				Error:       nil,
-			}
-			close(stopchan)
+		case <-stopChan:
 			return nil
-		default:
+		case <-ticker:
 			privateKey, cert, cacert, err := watchCertificate(dir, email, aliyunAccessKey, aliyunAccessSecret, domain, oldCert)
 			if err != nil {
-				newchan <- NewCert{
+				newCertChan <- NewCert{
 					Error: fmt.Errorf("watch cert failed: %s", err.Error()),
 				}
 			} else if privateKey != nil && cert != nil && cacert != nil {
 				oldCert = cert
-				newchan <- NewCert{
+				newCertChan <- NewCert{
 					PrivateKey:        privateKey,
 					Certificate:       cert,
 					IssuerCertificate: cacert,

+ 10 - 0
src/cmd/huanproxy/huanproxyv1/main.go

@@ -0,0 +1,10 @@
+package main
+
+import (
+	"github.com/SongZihuan/huan-proxy/src/mainfunc/huanproxy"
+	"github.com/SongZihuan/huan-proxy/src/utils"
+)
+
+func main() {
+	utils.Exit(huanproxy.MainV1())
+}

+ 0 - 10
src/cmd/version1/main.go

@@ -1,10 +0,0 @@
-package main
-
-import (
-	"github.com/SongZihuan/huan-proxy/src/mainfunc"
-	"github.com/SongZihuan/huan-proxy/src/utils"
-)
-
-func main() {
-	utils.Exit(mainfunc.MainV1())
-}

+ 3 - 3
src/config/globalconfig.go

@@ -27,9 +27,9 @@ var levelMap = map[string]bool{
 
 type GlobalConfig struct {
 	Mode     string           `yaml:"mode"`
-	LogLevel string           `yaml:"loglevel"`
-	LogTag   utils.StringBool `yaml:"logtag"`
-	NotAbort utils.StringBool `yaml:"notabort"`
+	LogLevel string           `yaml:"log-level"`
+	LogTag   utils.StringBool `yaml:"log-tag"`
+	NotAbort utils.StringBool `yaml:"not-abort"`
 }
 
 func (g *GlobalConfig) SetDefault() {

+ 6 - 2
src/config/httpconfig.go

@@ -2,11 +2,13 @@ package config
 
 import (
 	"github.com/SongZihuan/huan-proxy/src/config/configerr"
+	"github.com/SongZihuan/huan-proxy/src/utils"
 )
 
 type HttpConfig struct {
-	Address        string `yaml:"address"`
-	StopWaitSecond int    `yaml:"stopwaitsecond"`
+	Address        string           `yaml:"address"`
+	StopWaitSecond int              `yaml:"stop-wait-second"`
+	ProxyProto     utils.StringBool `yaml:"proxy-proto"`
 }
 
 func (h *HttpConfig) SetDefault() {
@@ -17,6 +19,8 @@ func (h *HttpConfig) SetDefault() {
 	if h.StopWaitSecond <= 0 {
 		h.StopWaitSecond = 10
 	}
+
+	h.ProxyProto.SetDefaultEnable()
 }
 
 func (h *HttpConfig) Check() configerr.ConfigError {

+ 10 - 7
src/config/httpsconfig.go

@@ -12,13 +12,14 @@ const (
 )
 
 type HttpsConfig struct {
-	Address               string `yaml:"address"`
-	SSLEmail              string `json:"sslemail"`
-	SSLDomain             string `yaml:"ssldomain"`
-	SSLCertDir            string `yaml:"sslcertdir"`
-	AliyunDNSAccessKey    string `yaml:"aliyundnsaccesskey"`
-	AliyunDNSAccessSecret string `yaml:"aliyundnsaccesssecret"`
-	StopWaitSecond        int    `yaml:"stopwaitsecond"`
+	Address               string           `yaml:"address"`
+	SSLEmail              string           `json:"ssl-email"`
+	SSLDomain             string           `yaml:"ssl-domain"`
+	SSLCertDir            string           `yaml:"ssl-cert-dir"`
+	AliyunDNSAccessKey    string           `yaml:"aliyun-dns-access-key"`
+	AliyunDNSAccessSecret string           `yaml:"aliyun-dns-access-secret"`
+	StopWaitSecond        int              `yaml:"stop-wait-second"`
+	ProxyProto            utils.StringBool `yaml:"proxy-proto"`
 }
 
 func (h *HttpsConfig) SetDefault() {
@@ -39,6 +40,8 @@ func (h *HttpsConfig) SetDefault() {
 	if h.StopWaitSecond <= 0 {
 		h.StopWaitSecond = 10
 	}
+
+	h.ProxyProto.SetDefaultEnable()
 }
 
 func (h *HttpsConfig) Check() configerr.ConfigError {

+ 8 - 8
src/config/rules/action/api/api.go

@@ -11,15 +11,15 @@ import (
 
 type RuleAPIConfig struct {
 	Address   string                `yaml:"address"`
-	AddPath   string                `yaml:"addpath"`
-	SubPath   string                `yaml:"subpath"`
+	AddPath   string                `yaml:"add-path"`
+	SubPath   string                `yaml:"sub-path"`
 	Rewrite   rewrite.RewriteConfig `yaml:"rewrite"`
-	HeaderSet []*ReqHeaderConfig    `yaml:"headerset"`
-	HeaderAdd []*ReqHeaderConfig    `yaml:"headeradd"`
-	HeaderDel []*ReqHeaderDelConfig `yaml:"headerdel"`
-	QuerySet  []*QueryConfig        `yaml:"queryset"`
-	QueryAdd  []*QueryConfig        `yaml:"queryadd"`
-	QueryDel  []*QueryDelConfig     `yaml:"querydel"`
+	HeaderSet []*ReqHeaderConfig    `yaml:"header-set"`
+	HeaderAdd []*ReqHeaderConfig    `yaml:"header-add"`
+	HeaderDel []*ReqHeaderDelConfig `yaml:"header-del"`
+	QuerySet  []*QueryConfig        `yaml:"query-set"`
+	QueryAdd  []*QueryConfig        `yaml:"query-add"`
+	QueryDel  []*QueryDelConfig     `yaml:"query-del"`
 	Via       string                `yaml:"via"`
 }
 

+ 4 - 4
src/config/rules/action/cors/corsconfig.go

@@ -10,10 +10,10 @@ const CorsMaxAgeSec = 86400
 const CorsDefaultMaxAgeSec = CorsMaxAgeSec
 
 type CorsConfig struct {
-	AllowCors      utils.StringBool `yaml:"allowcors"`
-	AllowOrigin    []string         `yaml:"alloworigin"`
-	AllowOriginReg []string         `yaml:"alloworiginres"`
-	MaxAgeSec      int              `yaml:"maxagesec"`
+	AllowCors        utils.StringBool `yaml:"allow-cors"`
+	AllowOrigin      []string         `yaml:"allow-origin"`
+	AllowOriginRegex []string         `yaml:"allow-origin-regex"`
+	MaxAgeSec        int              `yaml:"max-age-sec"`
 }
 
 func (c *CorsConfig) SetDefault() {

+ 5 - 5
src/config/rules/action/dir/dirconfig.go

@@ -7,11 +7,11 @@ import (
 )
 
 type RuleDirConfig struct {
-	BasePath   string                `yaml:"basepath"`
-	IndexFile  []*IndexFileConfig    `yaml:"indexfile"`
-	IgnoreFile []*IgnoreFileConfig   `yaml:"ignorefile"`
-	AddPath    string                `yaml:"addpath"`
-	SubPath    string                `yaml:"subpath"`
+	BasePath   string                `yaml:"base-path"`
+	IndexFile  []*IndexFileConfig    `yaml:"index-file"`
+	IgnoreFile []*IgnoreFileConfig   `yaml:"ignore-file"`
+	AddPath    string                `yaml:"add-path"`
+	SubPath    string                `yaml:"sub-path"`
 	Rewrite    rewrite.RewriteConfig `yaml:"rewrite"`
 	Cors       cors.CorsConfig       `yaml:"cors"`
 }

+ 5 - 5
src/config/rules/action/remotetrust/remotetrustconfig.go

@@ -7,20 +7,20 @@ import (
 )
 
 type RemoteTrustConfig struct {
-	RemoteTrust utils.StringBool `yaml:"remotetrust"`
-	TrustedIPs  []string         `yaml:"trustedips"`
+	RemoteTrusted utils.StringBool `yaml:"remote-trusted"`
+	TrustedIPs    []string         `yaml:"trusted-ips"`
 }
 
 func (p *RemoteTrustConfig) SetDefault() {
-	p.RemoteTrust.SetDefaultDisable()
+	p.RemoteTrusted.SetDefaultDisable()
 
-	if p.RemoteTrust.IsEnable() && len(p.TrustedIPs) == 0 {
+	if p.RemoteTrusted.IsEnable() && len(p.TrustedIPs) == 0 {
 		p.TrustedIPs = []string{"127.0.0.0/8", "::1"}
 	}
 }
 
 func (p *RemoteTrustConfig) Check() configerr.ConfigError {
-	if p.RemoteTrust.IsEnable() {
+	if p.RemoteTrusted.IsEnable() {
 		for _, ip := range p.TrustedIPs {
 			if !utils.ValidIPv4(ip) && !utils.ValidIPv6(ip) && !utils.IsValidIPv4CIDR(ip) && !utils.IsValidIPv6CIDR(ip) {
 				return configerr.NewConfigError(fmt.Sprintf("bad proxy trusts ip address: %s", ip))

+ 3 - 3
src/config/rules/action/respheader/respheaderconfig.go

@@ -5,9 +5,9 @@ import (
 )
 
 type SetRespHeaderConfig struct {
-	HeaderSet []*RespHeaderConfig    `yaml:"headerret"`
-	HeaderAdd []*RespHeaderConfig    `yaml:"headeradd"`
-	HeaderDel []*RespHeaderDelConfig `yaml:"headerdel"`
+	HeaderSet []*RespHeaderConfig    `yaml:"header-set"`
+	HeaderAdd []*RespHeaderConfig    `yaml:"header-add"`
+	HeaderDel []*RespHeaderDelConfig `yaml:"header-del"`
 }
 
 func (s *SetRespHeaderConfig) SetDefault() {

+ 2 - 2
src/config/rules/match/matchconfig.go

@@ -12,8 +12,8 @@ const (
 )
 
 type MatchConfig struct {
-	MatchType string `yaml:"matchtype"`
-	MatchPath string `yaml:"matchpath"`
+	MatchType string `yaml:"match-type"`
+	MatchPath string `yaml:"match-path"`
 }
 
 func (m *MatchConfig) SetDefault() {

+ 1 - 1
src/config/rules/ruleconfig.go

@@ -28,7 +28,7 @@ type RuleConfig struct {
 	Dir        dir.RuleDirConfig              `yaml:"dir"`
 	Api        api.RuleAPIConfig              `yaml:"api"`
 	Redirect   redirect.RuleRedirectConfig    `yaml:"redirect"`
-	RespHeader respheader.SetRespHeaderConfig `yaml:"respHeader"`
+	RespHeader respheader.SetRespHeaderConfig `yaml:"response-header"`
 }
 
 func (p *RuleConfig) SetDefault() {

+ 2 - 2
src/config/rulescompile/actioncompile/corscompile/corsconfig.go

@@ -27,8 +27,8 @@ func NewCorsCompileConfig(c *cors.CorsConfig) (*CorsCompileConfig, error) {
 		}, nil
 	}
 
-	regexps := make([]*regexp.Regexp, 0, len(c.AllowOriginReg))
-	for _, v := range c.AllowOriginReg {
+	regexps := make([]*regexp.Regexp, 0, len(c.AllowOriginRegex))
+	for _, v := range c.AllowOriginRegex {
 		reg, err := regexp.Compile(v)
 		if err != nil {
 			return nil, err

+ 1 - 1
src/config/rulescompile/actioncompile/remotetrustcompile/remotetrustcompileconfig.go

@@ -8,7 +8,7 @@ type RemoteTrustCompileConfig struct {
 }
 
 func NewRemoteTrustCompileConfig(r *remotetrust.RemoteTrustConfig) (*RemoteTrustCompileConfig, error) {
-	if r.RemoteTrust.IsDisable(false) {
+	if r.RemoteTrusted.IsDisable(false) {
 		return &RemoteTrustCompileConfig{
 			UseTrustedIPs: false,
 			TrustedIPs:    make([]string, 0),

+ 11 - 6
src/mainfunc/v1.go → src/mainfunc/huanproxy/v1.go

@@ -1,4 +1,4 @@
-package mainfunc
+package huanproxy
 
 import (
 	"errors"
@@ -12,6 +12,7 @@ import (
 	"github.com/SongZihuan/huan-proxy/src/server/httpsserver"
 	"github.com/SongZihuan/huan-proxy/src/utils"
 	"os"
+	"time"
 )
 
 func MainV1() int {
@@ -68,18 +69,22 @@ func MainV1() int {
 
 	ser := server.NewHuanProxyServer()
 
-	httpschan := make(chan error)
-	httpchan := make(chan error)
+	httpErrorChan := make(chan error)
+	httpsErrorChan := make(chan error)
 
-	err = ser.Run(httpschan, httpchan)
+	err = ser.Run(httpErrorChan, httpsErrorChan)
 	if err != nil {
 		return utils.ExitByErrorMsg(fmt.Sprintf("run http/https error: %s", err.Error()))
 	}
+	defer func() {
+		_ = ser.Stop()
+		time.Sleep(1 * time.Second)
+	}()
 
 	select {
 	case <-config.GetSignalChan():
 		return 0
-	case err := <-httpchan:
+	case err := <-httpErrorChan:
 		if errors.Is(err, httpserver.ServerStop) {
 			return 0
 		} else if err != nil {
@@ -87,7 +92,7 @@ func MainV1() int {
 		} else {
 			return 0
 		}
-	case err := <-httpschan:
+	case err := <-httpsErrorChan:
 		if errors.Is(err, httpsserver.ServerStop) {
 			return 0
 		} else if err != nil {

+ 53 - 7
src/server/httpserver/server.go

@@ -1,11 +1,15 @@
 package httpserver
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"github.com/SongZihuan/huan-proxy/src/config"
 	"github.com/SongZihuan/huan-proxy/src/logger"
+	"github.com/pires/go-proxyproto"
+	"net"
 	"net/http"
+	"time"
 )
 
 var ServerStop = fmt.Errorf("https server stop")
@@ -38,18 +42,60 @@ func (s *HTTPServer) LoadHttp() error {
 	return nil
 }
 
-func (s *HTTPServer) RunHttp(_httpschan chan error) chan error {
-	go func(httpschan chan error) {
+func (s *HTTPServer) StopHttp() error {
+	if s.server == nil {
+		return nil
+	}
+
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Minute)
+	defer cancelFunc()
+
+	err := s.server.Shutdown(ctx)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (s *HTTPServer) RunHttp(httpErrorChan chan error) {
+	go func() {
+		listener, err := s.getListener()
+		if err != nil {
+			httpErrorChan <- fmt.Errorf("listen fail")
+			return
+		}
+		defer func() {
+			_ = listener.Close()
+		}()
+
 		logger.Infof("start http server in %s", s.cfg.Address)
-		err := s.server.ListenAndServe()
+		err = s.server.Serve(listener)
 		if err != nil && errors.Is(err, http.ErrServerClosed) {
-			httpschan <- ServerStop
+			httpErrorChan <- ServerStop
 			return
 		} else if err != nil {
-			httpschan <- err
+			httpErrorChan <- err
 			return
 		}
-	}(_httpschan)
+	}()
+}
+
+func (s *HTTPServer) getListener() (net.Listener, error) {
+	tcpListener, err := net.Listen("tcp", s.server.Addr)
+	if err != nil {
+		return nil, fmt.Errorf("tcp listen on %s: %s\n", s.server.Addr, err.Error())
+	}
+
+	var proxyListener net.Listener
+	if s.cfg.ProxyProto.IsEnable(true) {
+		proxyListener = &proxyproto.Listener{
+			Listener:          tcpListener,
+			ReadHeaderTimeout: 10 * time.Second,
+		}
+	} else {
+		proxyListener = tcpListener
+	}
 
-	return _httpschan
+	return proxyListener, nil
 }

+ 127 - 60
src/server/httpsserver/server.go

@@ -10,6 +10,8 @@ import (
 	"github.com/SongZihuan/huan-proxy/src/certssl"
 	"github.com/SongZihuan/huan-proxy/src/config"
 	"github.com/SongZihuan/huan-proxy/src/logger"
+	"github.com/pires/go-proxyproto"
+	"net"
 	"net/http"
 	"sync"
 	"time"
@@ -28,14 +30,14 @@ type HTTPSServer struct {
 }
 
 func NewHTTPSServer(handler http.Handler) *HTTPSServer {
-	httpscfg := config.GetConfig().Https
+	httpsCfg := config.GetConfig().Https
 
-	if httpscfg.Address == "" {
+	if httpsCfg.Address == "" {
 		return nil
 	}
 
 	return &HTTPSServer{
-		cfg:     &httpscfg,
+		cfg:     &httpsCfg,
 		server:  nil,
 		handler: handler,
 	}
@@ -61,6 +63,22 @@ func (s *HTTPSServer) LoadHttps() error {
 	return nil
 }
 
+func (s *HTTPSServer) StopHttps() error {
+	if s.server == nil {
+		return nil
+	}
+
+	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Minute)
+	defer cancelFunc()
+
+	err := s.server.Shutdown(ctx)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func (s *HTTPSServer) reloadHttps() error {
 	if s.key == nil || s.cert == nil || s.cacert == nil {
 		return fmt.Errorf("init https server error: get key and cert error, return nil, unknown reason")
@@ -70,6 +88,22 @@ func (s *HTTPSServer) reloadHttps() error {
 		return fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
 	}
 
+	s.server = &http.Server{
+		Addr:    s.cfg.Address,
+		Handler: s.handler,
+	}
+	return nil
+}
+
+func (s *HTTPSServer) getListener() (net.Listener, error) {
+	if s.key == nil || s.cert == nil || s.cacert == nil {
+		return nil, fmt.Errorf("init https server error: get key and cert error, return nil, unknown reason")
+	}
+
+	if s.cert.Raw == nil || len(s.cert.Raw) == 0 || s.cacert.Raw == nil || len(s.cacert.Raw) == 0 {
+		return nil, fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
+	}
+
 	tlsConfig := &tls.Config{
 		Certificates: []tls.Certificate{{
 			Certificate: [][]byte{s.cert.Raw, s.cacert.Raw}, // Raw包含 DER 编码的证书
@@ -79,87 +113,120 @@ func (s *HTTPSServer) reloadHttps() error {
 		MinVersion: tls.VersionTLS12,
 	}
 
-	s.server = &http.Server{
-		Addr:      s.cfg.Address,
-		Handler:   s.handler,
-		TLSConfig: tlsConfig,
+	tcpListener, err := net.Listen("tcp", s.server.Addr)
+	if err != nil {
+		return nil, fmt.Errorf("tcp listen on %s: %s\n", s.server.Addr, err.Error())
 	}
 
-	return nil
+	var proxyListener net.Listener
+	if s.cfg.ProxyProto.IsEnable(true) {
+		proxyListener = &proxyproto.Listener{
+			Listener:          tcpListener,
+			ReadHeaderTimeout: 10 * time.Second,
+		}
+	} else {
+		proxyListener = tcpListener
+	}
+
+	tlsListener := tls.NewListener(proxyListener, tlsConfig)
+
+	return tlsListener, nil
 }
 
-func (s *HTTPSServer) RunHttps(_httpschan chan error) chan error {
-	_watchstopchan := make(chan bool)
+func (s *HTTPSServer) RunHttps(httpsErrorChan chan error) {
+	watchStopChan := make(chan bool)
 
-	s.watchCertificate(_watchstopchan)
+	s.watchCertificate(watchStopChan)
 
-	go func(httpschan chan error, watchstopchan chan bool) {
+	go func() {
 		defer func() {
-			watchstopchan <- true
+			close(watchStopChan)
 		}()
-	ListenCycle:
+
+		defer func() {
+			s.server = nil
+		}()
+
 		for {
-			logger.Infof("start https server in %s", s.cfg.Address)
-			err := s.server.ListenAndServeTLS("", "")
-			if err != nil && errors.Is(err, http.ErrServerClosed) {
-				if s.reloadMutex.TryLock() {
-					s.reloadMutex.Unlock()
-					_httpschan <- ServerStop
-					return
+			res := func() bool {
+				listener, err := s.getListener()
+				if err != nil {
+					httpsErrorChan <- fmt.Errorf("listen fail")
+					return true
 				}
-				s.reloadMutex.Lock()
-				s.reloadMutex.Unlock() // 等待证书更换完毕
-				continue ListenCycle
-			} else if err != nil {
-				_httpschan <- fmt.Errorf("https server error: %s", err.Error())
+				defer func() {
+					_ = listener.Close()
+				}()
+
+				logger.Infof("start https server in %s", s.cfg.Address)
+				err = s.server.Serve(listener)
+				if err != nil && errors.Is(err, http.ErrServerClosed) {
+					if s.reloadMutex.TryLock() {
+						s.reloadMutex.Unlock()
+						httpsErrorChan <- ServerStop
+						return true
+					}
+
+					s.reloadMutex.Lock()
+					s.reloadMutex.Unlock() // 等待证书更换完毕
+					return false
+				} else if err != nil {
+					httpsErrorChan <- fmt.Errorf("https server error: %s", err.Error())
+					return true
+				}
+
+				return false
+			}()
+			if res {
 				return
 			}
 		}
-	}(_httpschan, _watchstopchan)
-
-	return _httpschan
+	}()
 }
 
-func (s *HTTPSServer) watchCertificate(stopchan chan bool) {
-	newchan := make(chan certssl.NewCert)
+func (s *HTTPSServer) watchCertificate(stopChan chan bool) {
+	newCertChan := make(chan certssl.NewCert)
 
 	go func() {
-		err := certssl.WatchCertificate(s.cfg.SSLCertDir, s.cfg.SSLEmail, s.cfg.AliyunDNSAccessKey, s.cfg.AliyunDNSAccessSecret, s.cfg.SSLDomain, s.cert, stopchan, newchan)
+		err := certssl.WatchCertificate(s.cfg.SSLCertDir, s.cfg.SSLEmail, s.cfg.AliyunDNSAccessKey, s.cfg.AliyunDNSAccessSecret, s.cfg.SSLDomain, s.cert, stopChan, newCertChan)
 		if err != nil {
 			logger.Errorf("watch https cert server error: %s", err.Error())
 		}
 	}()
 
 	go func() {
-		select {
-		case res := <-newchan:
-			if res.Certificate == nil && res.PrivateKey == nil && res.Error == nil {
-				close(newchan)
-				return
-			} else if res.Error != nil {
-				logger.Errorf("https cert reload server error: %s", res.Error.Error())
-			} else if res.PrivateKey != nil && res.Certificate != nil && res.IssuerCertificate != nil {
-				func() {
-					s.reloadMutex.Lock()
-					defer s.reloadMutex.Unlock()
-
-					ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-					defer cancel()
-
-					err := s.server.Shutdown(ctx)
-					if err != nil {
-						logger.Errorf("https server reload shutdown error: %s", err.Error())
-					}
+		defer close(newCertChan)
 
-					s.key = res.PrivateKey
-					s.cert = res.Certificate
-					s.cacert = res.IssuerCertificate
-
-					err = s.reloadHttps()
-					if err != nil {
-						logger.Errorf("https server reload init error: %s", err.Error())
-					}
-				}()
+		for {
+			select {
+			case <-stopChan:
+				return
+			case res := <-newCertChan:
+				if res.Error != nil {
+					logger.Errorf("https cert reload server error: %s", res.Error.Error())
+				} else if res.PrivateKey != nil && res.Certificate != nil && res.IssuerCertificate != nil {
+					func() {
+						s.reloadMutex.Lock()
+						defer s.reloadMutex.Unlock()
+
+						ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+						defer cancel()
+
+						err := s.server.Shutdown(ctx)
+						if err != nil {
+							logger.Errorf("https server reload shutdown error: %s", err.Error())
+						}
+
+						s.key = res.PrivateKey
+						s.cert = res.Certificate
+						s.cacert = res.IssuerCertificate
+
+						err = s.reloadHttps()
+						if err != nil {
+							logger.Errorf("https server reload init error: %s", err.Error())
+						}
+					}()
+				}
 			}
 		}
 	}()

+ 15 - 3
src/server/main.go

@@ -34,14 +34,14 @@ func NewHuanProxyServer() *HuanProxyServer {
 	return res
 }
 
-func (s *HuanProxyServer) Run(httpschan chan error, httpchan chan error) (err error) {
+func (s *HuanProxyServer) Run(httpErrorChan chan error, httpsErrorChan chan error) (err error) {
 	if s.https != nil {
 		err := s.https.LoadHttps()
 		if err != nil {
 			return err
 		}
 
-		s.https.RunHttps(httpschan)
+		s.https.RunHttps(httpsErrorChan)
 	}
 
 	if s.http != nil {
@@ -50,7 +50,19 @@ func (s *HuanProxyServer) Run(httpschan chan error, httpchan chan error) (err er
 			return err
 		}
 
-		s.http.RunHttp(httpchan)
+		s.http.RunHttp(httpErrorChan)
+	}
+
+	return nil
+}
+
+func (s *HuanProxyServer) Stop() (err error) {
+	if s.http != nil {
+		_ = s.http.StopHttp()
+	}
+
+	if s.https != nil {
+		_ = s.https.StopHttps()
 	}
 
 	return nil