server.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package httpsslserver
  2. import (
  3. "context"
  4. "crypto"
  5. "crypto/tls"
  6. "crypto/x509"
  7. "errors"
  8. "fmt"
  9. "github.com/SongZihuan/Http-Demo/src/certssl"
  10. "github.com/SongZihuan/Http-Demo/src/engine"
  11. "github.com/SongZihuan/Http-Demo/src/flagparser"
  12. "net/http"
  13. "sync"
  14. "time"
  15. )
  // Package-level state shared by the HTTPS server and the certificate watcher.
  var (
  	// HttpSSLServer is the HTTPS server assigned by initHttpSSLServer; nil until then.
  	HttpSSLServer *http.Server

  	// HttpSSLAddress, HttpSSLDomain, HttpSSLEmail and HttpSSLCertDir hold the
  	// HTTPS settings copied from flagparser by InitHttpSSLServer.
  	HttpSSLAddress string
  	HttpSSLDomain  string
  	HttpSSLEmail   string
  	HttpSSLCertDir string

  	// PrivateKey and Certificate are the key pair initHttpSSLServer uses to
  	// build the server's TLS configuration.
  	PrivateKey  crypto.PrivateKey
  	Certificate *x509.Certificate

  	// ErrStop is returned by the serve loop when the server was shut down
  	// instead of being restarted for a certificate reload.
  	ErrStop = errors.New("http server error")

  	// ReloadMutex serializes certificate reloads with the serve loop.
  	ReloadMutex sync.Mutex
  )
  25. func InitHttpSSLServer() (err error) {
  26. HttpSSLAddress = flagparser.HttpsAddress
  27. HttpSSLDomain = flagparser.HttpsDomain
  28. HttpSSLEmail = flagparser.HttpsEmail
  29. HttpSSLCertDir = flagparser.HttpsCertDir
  30. PrivateKey, Certificate, err = certssl.GetCertificateAndPrivateKey(HttpSSLCertDir, HttpSSLEmail, HttpSSLAddress, HttpSSLDomain)
  31. if err != nil {
  32. return err
  33. }
  34. return initHttpSSLServer()
  35. }
  36. func initHttpSSLServer() (err error) {
  37. tlsConfig := &tls.Config{
  38. Certificates: []tls.Certificate{{
  39. Certificate: [][]byte{Certificate.Raw}, // Raw包含 DER 编码的证书
  40. PrivateKey: PrivateKey,
  41. Leaf: Certificate,
  42. }},
  43. }
  44. HttpSSLServer = &http.Server{
  45. Addr: HttpSSLAddress,
  46. Handler: engine.Engine,
  47. TLSConfig: tlsConfig,
  48. }
  49. return nil
  50. }
  51. func RunServer() error {
  52. stopchan := make(chan bool)
  53. WatchCert(stopchan)
  54. err := runServer()
  55. stopchan <- true
  56. return err
  57. }
  58. func runServer() error {
  59. fmt.Printf("https server start at %s\n", HttpSSLAddress)
  60. ListenCycle:
  61. for {
  62. err := HttpSSLServer.ListenAndServeTLS("", "")
  63. if err != nil && errors.Is(err, http.ErrServerClosed) {
  64. if ReloadMutex.TryLock() {
  65. ReloadMutex.Unlock()
  66. return ErrStop
  67. }
  68. ReloadMutex.Lock()
  69. ReloadMutex.Unlock() // 等待证书更换完毕
  70. continue ListenCycle
  71. } else if err != nil {
  72. return err
  73. }
  74. }
  75. }
  76. func WatchCert(stopchan chan bool) {
  77. newchan := make(chan certssl.NewCert)
  78. go func() {
  79. err := certssl.WatchCertificateAndPrivateKey(HttpSSLCertDir, HttpSSLEmail, HttpSSLAddress, HttpSSLDomain, PrivateKey, Certificate, stopchan, newchan)
  80. if err != nil {
  81. fmt.Printf("watch cert error: %s", err.Error())
  82. }
  83. }()
  84. go func() {
  85. select {
  86. case res := <-newchan:
  87. if res.Certificate == nil && res.PrivateKey == nil && res.Error == nil {
  88. close(newchan)
  89. return
  90. } else if res.Error != nil {
  91. fmt.Printf("watch cert error: %s", res.Error.Error())
  92. } else if res.PrivateKey == nil && res.Certificate == nil {
  93. func() {
  94. ReloadMutex.Lock()
  95. defer ReloadMutex.Unlock()
  96. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  97. defer cancel()
  98. err := HttpSSLServer.Shutdown(ctx)
  99. if err != nil {
  100. fmt.Printf("reload error: %s", err.Error())
  101. }
  102. PrivateKey = res.PrivateKey
  103. Certificate = res.Certificate
  104. }()
  105. }
  106. }
  107. }()
  108. }