|
@@ -10,6 +10,8 @@ import (
|
|
"github.com/SongZihuan/huan-proxy/src/certssl"
|
|
"github.com/SongZihuan/huan-proxy/src/certssl"
|
|
"github.com/SongZihuan/huan-proxy/src/config"
|
|
"github.com/SongZihuan/huan-proxy/src/config"
|
|
"github.com/SongZihuan/huan-proxy/src/logger"
|
|
"github.com/SongZihuan/huan-proxy/src/logger"
|
|
|
|
+ "github.com/pires/go-proxyproto"
|
|
|
|
+ "net"
|
|
"net/http"
|
|
"net/http"
|
|
"sync"
|
|
"sync"
|
|
"time"
|
|
"time"
|
|
@@ -28,14 +30,14 @@ type HTTPSServer struct {
|
|
}
|
|
}
|
|
|
|
|
|
func NewHTTPSServer(handler http.Handler) *HTTPSServer {
|
|
func NewHTTPSServer(handler http.Handler) *HTTPSServer {
|
|
- httpscfg := config.GetConfig().Https
|
|
|
|
|
|
+ httpsCfg := config.GetConfig().Https
|
|
|
|
|
|
- if httpscfg.Address == "" {
|
|
|
|
|
|
+ if httpsCfg.Address == "" {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
return &HTTPSServer{
|
|
return &HTTPSServer{
|
|
- cfg: &httpscfg,
|
|
|
|
|
|
+ cfg: &httpsCfg,
|
|
server: nil,
|
|
server: nil,
|
|
handler: handler,
|
|
handler: handler,
|
|
}
|
|
}
|
|
@@ -61,6 +63,22 @@ func (s *HTTPSServer) LoadHttps() error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func (s *HTTPSServer) StopHttps() error {
|
|
|
|
+ if s.server == nil {
|
|
|
|
+ return nil
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Minute)
|
|
|
|
+ defer cancelFunc()
|
|
|
|
+
|
|
|
|
+ err := s.server.Shutdown(ctx)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
func (s *HTTPSServer) reloadHttps() error {
|
|
func (s *HTTPSServer) reloadHttps() error {
|
|
if s.key == nil || s.cert == nil || s.cacert == nil {
|
|
if s.key == nil || s.cert == nil || s.cacert == nil {
|
|
return fmt.Errorf("init https server error: get key and cert error, return nil, unknown reason")
|
|
return fmt.Errorf("init https server error: get key and cert error, return nil, unknown reason")
|
|
@@ -70,6 +88,22 @@ func (s *HTTPSServer) reloadHttps() error {
|
|
return fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
|
|
return fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ s.server = &http.Server{
|
|
|
|
+ Addr: s.cfg.Address,
|
|
|
|
+ Handler: s.handler,
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (s *HTTPSServer) getListener() (net.Listener, error) {
|
|
|
|
+ if s.key == nil || s.cert == nil || s.cacert == nil {
|
|
|
|
+ return nil, fmt.Errorf("init https server error: get key and cert error, return nil, unknown reason")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if s.cert.Raw == nil || len(s.cert.Raw) == 0 || s.cacert.Raw == nil || len(s.cacert.Raw) == 0 {
|
|
|
|
+ return nil, fmt.Errorf("init https server error: get cert.raw error, return nil, unknown reason")
|
|
|
|
+ }
|
|
|
|
+
|
|
tlsConfig := &tls.Config{
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{{
|
|
Certificates: []tls.Certificate{{
|
|
Certificate: [][]byte{s.cert.Raw, s.cacert.Raw}, // Raw包含 DER 编码的证书
|
|
Certificate: [][]byte{s.cert.Raw, s.cacert.Raw}, // Raw包含 DER 编码的证书
|
|
@@ -79,87 +113,120 @@ func (s *HTTPSServer) reloadHttps() error {
|
|
MinVersion: tls.VersionTLS12,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
}
|
|
|
|
|
|
- s.server = &http.Server{
|
|
|
|
- Addr: s.cfg.Address,
|
|
|
|
- Handler: s.handler,
|
|
|
|
- TLSConfig: tlsConfig,
|
|
|
|
|
|
+ tcpListener, err := net.Listen("tcp", s.server.Addr)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, fmt.Errorf("tcp listen on %s: %s\n", s.server.Addr, err.Error())
|
|
}
|
|
}
|
|
|
|
|
|
- return nil
|
|
|
|
|
|
+ var proxyListener net.Listener
|
|
|
|
+ if s.cfg.ProxyProto.IsEnable(true) {
|
|
|
|
+ proxyListener = &proxyproto.Listener{
|
|
|
|
+ Listener: tcpListener,
|
|
|
|
+ ReadHeaderTimeout: 10 * time.Second,
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ proxyListener = tcpListener
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ tlsListener := tls.NewListener(proxyListener, tlsConfig)
|
|
|
|
+
|
|
|
|
+ return tlsListener, nil
|
|
}
|
|
}
|
|
|
|
|
|
-func (s *HTTPSServer) RunHttps(_httpschan chan error) chan error {
|
|
|
|
- _watchstopchan := make(chan bool)
|
|
|
|
|
|
+func (s *HTTPSServer) RunHttps(httpsErrorChan chan error) {
|
|
|
|
+ watchStopChan := make(chan bool)
|
|
|
|
|
|
- s.watchCertificate(_watchstopchan)
|
|
|
|
|
|
+ s.watchCertificate(watchStopChan)
|
|
|
|
|
|
- go func(httpschan chan error, watchstopchan chan bool) {
|
|
|
|
|
|
+ go func() {
|
|
defer func() {
|
|
defer func() {
|
|
- watchstopchan <- true
|
|
|
|
|
|
+ close(watchStopChan)
|
|
}()
|
|
}()
|
|
- ListenCycle:
|
|
|
|
|
|
+
|
|
|
|
+ defer func() {
|
|
|
|
+ s.server = nil
|
|
|
|
+ }()
|
|
|
|
+
|
|
for {
|
|
for {
|
|
- logger.Infof("start https server in %s", s.cfg.Address)
|
|
|
|
- err := s.server.ListenAndServeTLS("", "")
|
|
|
|
- if err != nil && errors.Is(err, http.ErrServerClosed) {
|
|
|
|
- if s.reloadMutex.TryLock() {
|
|
|
|
- s.reloadMutex.Unlock()
|
|
|
|
- _httpschan <- ServerStop
|
|
|
|
- return
|
|
|
|
|
|
+ res := func() bool {
|
|
|
|
+ listener, err := s.getListener()
|
|
|
|
+ if err != nil {
|
|
|
|
+ httpsErrorChan <- fmt.Errorf("listen fail")
|
|
|
|
+ return true
|
|
}
|
|
}
|
|
- s.reloadMutex.Lock()
|
|
|
|
- s.reloadMutex.Unlock() // 等待证书更换完毕
|
|
|
|
- continue ListenCycle
|
|
|
|
- } else if err != nil {
|
|
|
|
- _httpschan <- fmt.Errorf("https server error: %s", err.Error())
|
|
|
|
|
|
+ defer func() {
|
|
|
|
+ _ = listener.Close()
|
|
|
|
+ }()
|
|
|
|
+
|
|
|
|
+ logger.Infof("start https server in %s", s.cfg.Address)
|
|
|
|
+ err = s.server.Serve(listener)
|
|
|
|
+ if err != nil && errors.Is(err, http.ErrServerClosed) {
|
|
|
|
+ if s.reloadMutex.TryLock() {
|
|
|
|
+ s.reloadMutex.Unlock()
|
|
|
|
+ httpsErrorChan <- ServerStop
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.reloadMutex.Lock()
|
|
|
|
+ s.reloadMutex.Unlock() // 等待证书更换完毕
|
|
|
|
+ return false
|
|
|
|
+ } else if err != nil {
|
|
|
|
+ httpsErrorChan <- fmt.Errorf("https server error: %s", err.Error())
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return false
|
|
|
|
+ }()
|
|
|
|
+ if res {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- }(_httpschan, _watchstopchan)
|
|
|
|
-
|
|
|
|
- return _httpschan
|
|
|
|
|
|
+ }()
|
|
}
|
|
}
|
|
|
|
|
|
-func (s *HTTPSServer) watchCertificate(stopchan chan bool) {
|
|
|
|
- newchan := make(chan certssl.NewCert)
|
|
|
|
|
|
+func (s *HTTPSServer) watchCertificate(stopChan chan bool) {
|
|
|
|
+ newCertChan := make(chan certssl.NewCert)
|
|
|
|
|
|
go func() {
|
|
go func() {
|
|
- err := certssl.WatchCertificate(s.cfg.SSLCertDir, s.cfg.SSLEmail, s.cfg.AliyunDNSAccessKey, s.cfg.AliyunDNSAccessSecret, s.cfg.SSLDomain, s.cert, stopchan, newchan)
|
|
|
|
|
|
+ err := certssl.WatchCertificate(s.cfg.SSLCertDir, s.cfg.SSLEmail, s.cfg.AliyunDNSAccessKey, s.cfg.AliyunDNSAccessSecret, s.cfg.SSLDomain, s.cert, stopChan, newCertChan)
|
|
if err != nil {
|
|
if err != nil {
|
|
logger.Errorf("watch https cert server error: %s", err.Error())
|
|
logger.Errorf("watch https cert server error: %s", err.Error())
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
|
|
|
|
go func() {
|
|
go func() {
|
|
- select {
|
|
|
|
- case res := <-newchan:
|
|
|
|
- if res.Certificate == nil && res.PrivateKey == nil && res.Error == nil {
|
|
|
|
- close(newchan)
|
|
|
|
- return
|
|
|
|
- } else if res.Error != nil {
|
|
|
|
- logger.Errorf("https cert reload server error: %s", res.Error.Error())
|
|
|
|
- } else if res.PrivateKey != nil && res.Certificate != nil && res.IssuerCertificate != nil {
|
|
|
|
- func() {
|
|
|
|
- s.reloadMutex.Lock()
|
|
|
|
- defer s.reloadMutex.Unlock()
|
|
|
|
-
|
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
|
- defer cancel()
|
|
|
|
-
|
|
|
|
- err := s.server.Shutdown(ctx)
|
|
|
|
- if err != nil {
|
|
|
|
- logger.Errorf("https server reload shutdown error: %s", err.Error())
|
|
|
|
- }
|
|
|
|
|
|
+ defer close(newCertChan)
|
|
|
|
|
|
- s.key = res.PrivateKey
|
|
|
|
- s.cert = res.Certificate
|
|
|
|
- s.cacert = res.IssuerCertificate
|
|
|
|
-
|
|
|
|
- err = s.reloadHttps()
|
|
|
|
- if err != nil {
|
|
|
|
- logger.Errorf("https server reload init error: %s", err.Error())
|
|
|
|
- }
|
|
|
|
- }()
|
|
|
|
|
|
+ for {
|
|
|
|
+ select {
|
|
|
|
+ case <-stopChan:
|
|
|
|
+ return
|
|
|
|
+ case res := <-newCertChan:
|
|
|
|
+ if res.Error != nil {
|
|
|
|
+ logger.Errorf("https cert reload server error: %s", res.Error.Error())
|
|
|
|
+ } else if res.PrivateKey != nil && res.Certificate != nil && res.IssuerCertificate != nil {
|
|
|
|
+ func() {
|
|
|
|
+ s.reloadMutex.Lock()
|
|
|
|
+ defer s.reloadMutex.Unlock()
|
|
|
|
+
|
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
|
+ defer cancel()
|
|
|
|
+
|
|
|
|
+ err := s.server.Shutdown(ctx)
|
|
|
|
+ if err != nil {
|
|
|
|
+ logger.Errorf("https server reload shutdown error: %s", err.Error())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ s.key = res.PrivateKey
|
|
|
|
+ s.cert = res.Certificate
|
|
|
|
+ s.cacert = res.IssuerCertificate
|
|
|
|
+
|
|
|
|
+ err = s.reloadHttps()
|
|
|
|
+ if err != nil {
|
|
|
|
+ logger.Errorf("https server reload init error: %s", err.Error())
|
|
|
|
+ }
|
|
|
|
+ }()
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}()
|