Browse Source

引入 Context 结构体并更新相关函数

引入了新的 Context 结构体来封装请求和响应的相关信息,并更新了多个函数以使用该结构体,从而提高代码的可维护性和一致性。同时,对文件读取、CORS 处理以及错误处理等逻辑进行了相应的调整。
SongZihuan 3 months ago
parent
commit
d5b86d8f4f

+ 21 - 12
src/server/abort.go

@@ -2,26 +2,35 @@ package server
 
 import "net/http"
 
-func (s *HuanProxyServer) abortForbidden(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusForbidden)
+func (s *HuanProxyServer) abort(ctx *Context, code int) {
+	if ctx.Abort {
+		return
+	}
+
+	ctx.Writer.WriteHeader(code)
+	ctx.Abort = true
+}
+
+func (s *HuanProxyServer) abortForbidden(ctx *Context) {
+	s.abort(ctx, http.StatusForbidden)
 }
 
-func (s *HuanProxyServer) abortNotFound(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusNotFound)
+func (s *HuanProxyServer) abortNotFound(ctx *Context) {
+	s.abort(ctx, http.StatusNotFound)
 }
 
-func (s *HuanProxyServer) abortNotAcceptable(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusNotAcceptable)
+func (s *HuanProxyServer) abortNotAcceptable(ctx *Context) {
+	s.abort(ctx, http.StatusNotAcceptable)
 }
 
-func (s *HuanProxyServer) abortMethodNotAllowed(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusMethodNotAllowed)
+func (s *HuanProxyServer) abortMethodNotAllowed(ctx *Context) {
+	s.abort(ctx, http.StatusMethodNotAllowed)
 }
 
-func (s *HuanProxyServer) abortServerError(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusInternalServerError)
+func (s *HuanProxyServer) abortServerError(ctx *Context) {
+	s.abort(ctx, http.StatusInternalServerError)
 }
 
-func (s *HuanProxyServer) abortNoContent(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusNoContent)
+func (s *HuanProxyServer) abortNoContent(ctx *Context) {
+	s.abort(ctx, http.StatusNoContent)
 }

+ 21 - 0
src/server/context.go

@@ -0,0 +1,21 @@
+package server
+
+import (
+	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
+	"net/http"
+)
+
+type Context struct {
+	Abort   bool
+	Writer  writer
+	Request *http.Request
+	Rule    *rulescompile.RuleCompileConfig
+}
+
+func NewContext(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) *Context {
+	return &Context{
+		Writer:  NewWriter(w),
+		Request: r,
+		Rule:    rule,
+	}
+}

+ 9 - 9
src/server/cors.go

@@ -6,30 +6,30 @@ import (
 	"net/http"
 )
 
-func (s *HuanProxyServer) cors(corsRule *corscompile.CorsCompileConfig, w http.ResponseWriter, r *http.Request) bool {
+func (s *HuanProxyServer) cors(corsRule *corscompile.CorsCompileConfig, ctx *Context) bool {
 	if corsRule.Ignore {
-		if r.Method == http.MethodOptions {
-			s.abortMethodNotAllowed(w)
+		if ctx.Request.Method == http.MethodOptions {
+			s.abortMethodNotAllowed(ctx)
 			return false
 		} else {
 			return true
 		}
 	}
 
-	origin := r.Header.Get("Origin")
+	origin := ctx.Request.Header.Get("Origin")
 	if origin == "" {
-		s.abortForbidden(w)
+		s.abortForbidden(ctx)
 		return false
 	}
 
 	if !corsRule.InOriginList(origin) {
-		s.abortForbidden(w)
+		s.abortForbidden(ctx)
 		return false
 	}
 
-	w.Header().Set("Access-Control-Allow-Origin", origin)
-	w.Header().Set("Access-Control-Allow-Methods", "GET,OPTIONS")
-	w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", corsRule.MaxAgeSec))
+	ctx.Writer.Header().Set("Access-Control-Allow-Origin", origin)
+	ctx.Writer.Header().Set("Access-Control-Allow-Methods", "GET,OPTIONS")
+	ctx.Writer.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", corsRule.MaxAgeSec))
 
 	return true
 }

+ 10 - 11
src/server/file.go

@@ -1,41 +1,40 @@
 package server
 
 import (
-	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"github.com/SongZihuan/huan-proxy/src/utils"
 	"github.com/gabriel-vasile/mimetype"
 	"net/http"
 	"os"
 )
 
-func (s *HuanProxyServer) fileServer(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) {
-	if !s.cors(rule.File.Cors, w, r) {
+func (s *HuanProxyServer) fileServer(ctx *Context) {
+	if !s.cors(ctx.Rule.File.Cors, ctx) {
 		return
 	}
 
 	if r.Method != http.MethodGet {
-		s.abortMethodNotAllowed(w)
+		s.abortMethodNotAllowed(ctx)
 		return
 	}
 
-	file, err := os.ReadFile(rule.File.Path)
+	file, err := os.ReadFile(ctx.Rule.File.Path)
 	if err != nil {
-		s.abortServerError(w)
+		s.abortServerError(ctx)
 		return
 	}
 
 	mimeType := mimetype.Detect(file)
 	accept := r.Header.Get("Accept")
 	if !utils.AcceptMimeType(accept, mimeType.String()) {
-		s.abortNotAcceptable(w)
+		s.abortNotAcceptable(ctx)
 		return
 	}
 
-	_, err = w.Write(file)
+	_, err = ctx.Writer.Write(file)
 	if err != nil {
-		s.abortServerError(w)
+		s.abortServerError(ctx)
 		return
 	}
-	w.Header().Set("Content-Type", mimeType.String())
-	s.statusOK(w)
+	ctx.Writer.Header().Set("Content-Type", mimeType.String())
+	s.statusOK(ctx)
 }

+ 2 - 2
src/server/loggerserver.go

@@ -154,8 +154,8 @@ func (s *HuanProxyServer) LoggerServerHTTP(_w http.ResponseWriter, r *http.Reque
 
 	param.RemoteAddr = r.RemoteAddr
 	param.Method = r.Method
-	param.StatusCode = w.Status
-	param.BodySize = w.Size
+	param.StatusCode = w.status
+	param.BodySize = w.Size()
 
 	if raw != "" {
 		path = path + "?" + raw

+ 5 - 7
src/server/proxytrust.go

@@ -1,23 +1,21 @@
 package server
 
 import (
-	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"github.com/SongZihuan/huan-proxy/src/utils"
 	"net"
-	"net/http"
 )
 
-func (s *HuanProxyServer) checkProxyTrust(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) bool {
-	if !rule.UseTrustedIPs {
+func (s *HuanProxyServer) checkProxyTrust(ctx *Context) bool {
+	if !ctx.Rule.UseTrustedIPs {
 		return true
 	}
 
-	if r.RemoteAddr == "" {
+	if ctx.Request.RemoteAddr == "" {
 		s.abortForbidden(w)
 		return false
 	}
 
-	remoteIPStr, _, err := net.SplitHostPort(r.RemoteAddr)
+	remoteIPStr, _, err := net.SplitHostPort(ctx.Request.RemoteAddr)
 	if err != nil {
 		s.abortForbidden(w)
 		return false
@@ -29,7 +27,7 @@ func (s *HuanProxyServer) checkProxyTrust(rule *rulescompile.RuleCompileConfig,
 		return false
 	}
 
-	for _, t := range rule.TrustedIPs {
+	for _, t := range ctx.Rule.TrustedIPs {
 		if utils.ValidIPv4(t) || utils.ValidIPv6(t) {
 			trustIP := net.ParseIP(t)
 			if trustIP == nil {

+ 4 - 5
src/server/serverhttp.go

@@ -1,7 +1,6 @@
 package server
 
 import (
-	"fmt"
 	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"net/http"
 )
@@ -15,14 +14,14 @@ func (s *HuanProxyServer) NormalServeHTTP(w http.ResponseWriter, r *http.Request
 				continue
 			}
 
-			if !s.checkProxyTrust(rule, w, r) {
+			ctx := NewContext(rule, w, r)
+
+			if !s.checkProxyTrust(ctx) {
 				return
 			}
 
-			fmt.Printf("rule.Type: %d\n", rule.Type)
-
 			if rule.Type == rulescompile.ProxyTypeFile {
-				s.fileServer(rule, w, r)
+				s.fileServer(ctx)
 				return
 			} else if rule.Type == rulescompile.ProxyTypeDir {
 				s.dirServer(rule, w, r)

+ 75 - 16
src/server/writer.go

@@ -1,30 +1,89 @@
 package server
 
-import "net/http"
+import (
+	"bytes"
+	"net/http"
+)
+
+type writer http.ResponseWriter
 
 type ResponseWriter struct {
-	http.ResponseWriter
-	Status int
-	Size   int64
+	writer
+	status  int
+	buffer  bytes.Buffer
+	written bool
+	header  http.Header
 }
 
-func NewWriter(w http.ResponseWriter) *ResponseWriter {
-	return &ResponseWriter{
-		ResponseWriter: w,
-		Status:         0,
+func NewWriter(w writer) *ResponseWriter {
+	res := &ResponseWriter{
+		writer: w,
+		status: 0,
+		header: make(http.Header, 10),
+	}
+
+	for n, h := range w.Header() {
+		nh := make([]string, 0, len(h))
+		copy(nh, h)
+		res.header[n] = nh
 	}
+
+	return res
+}
+
+func (r *ResponseWriter) Size() int64 {
+	return int64(r.buffer.Len())
 }
 
 func (r *ResponseWriter) Write(p []byte) (int, error) {
-	n, err := r.ResponseWriter.Write(p)
-	if err != nil {
-		return n, err
-	}
-	r.Size += int64(n)
-	return n, nil
+	return r.buffer.Write(p)
 }
 
 func (r *ResponseWriter) WriteHeader(statusCode int) {
-	r.Status = statusCode
-	r.ResponseWriter.WriteHeader(statusCode)
+	r.status = statusCode
+}
+
+func (r *ResponseWriter) Header() http.Header {
+	return r.header
+}
+
+func (r *ResponseWriter) Reset() {
+	r.status = 0
+	r.header = make(http.Header, 10)
+	r.buffer.Reset()
+	r.written = false
+}
+
+func (r *ResponseWriter) WriteToResponse() error {
+	if r.written {
+		return nil
+	}
+
+	_, err := r.writer.Write(r.buffer.Bytes())
+	if err != nil {
+		return err
+	}
+
+	r.writer.WriteHeader(r.status)
+
+	writerHeader := r.writer.Header()
+	for n, h := range r.header {
+		nh := make([]string, 0, len(h))
+		copy(nh, h)
+		writerHeader[n] = nh
+	}
+
+	delHeader := make([]string, 0, 10)
+	for n, _ := range writerHeader {
+		if _, ok := r.header[n]; !ok {
+			delHeader = append(delHeader, n)
+		}
+	}
+
+	for _, n := range delHeader {
+		delete(writerHeader, n)
+	}
+
+	r.written = true
+	return nil
 }