Преглед изворни кода

重构代码以使用上下文对象

将多个处理函数中的参数从 `http.ResponseWriter` 和 `*http.Request` 更改为自定义的 `Context` 对象,以便更好地管理和传递请求和响应信息。同时,移除了不再使用的 `respose.go` 文件,并对相关文件进行了调整以适应新的上下文对象结构。
SongZihuan пре 3 месеци
родитељ
комит
1718368583

+ 11 - 1
src/server/abort.go

@@ -7,8 +7,18 @@ func (s *HuanProxyServer) abort(ctx *Context, code int) {
 		return
 	}
 
-	ctx.Writer.WriteHeader(code)
 	ctx.Abort = true
+	w, ok := ctx.Writer.(*ResponseWriter)
+	if !ok {
+		return
+	}
+
+	err := w.Reset()
+	if err != nil {
+		ctx.Writer.WriteHeader(http.StatusInternalServerError)
+	} else {
+		ctx.Writer.WriteHeader(code)
+	}
 }
 
 func (s *HuanProxyServer) abortForbidden(ctx *Context) {

+ 59 - 52
src/server/api.go

@@ -2,59 +2,64 @@ package server
 
 import (
 	"fmt"
-	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"github.com/SongZihuan/huan-proxy/src/config/rulescompile/actioncompile/rewritecompile"
 	"github.com/SongZihuan/huan-proxy/src/utils"
 	"net"
-	"net/http"
 	"strings"
 )
 
-func (s *HuanProxyServer) apiServer(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) {
-	proxy := rule.Api.Server
+func (s *HuanProxyServer) apiServer(ctx *Context) {
+	proxy := ctx.Rule.Api.Server
 	if proxy == nil {
-		s.abortServerError(w)
+		s.abortServerError(ctx)
 		return
 	}
 
-	targetURL := rule.Api.TargetURL
-	r.URL.Scheme = targetURL.Scheme
-	r.URL.Host = targetURL.Host
+	targetURL := ctx.Rule.Api.TargetURL
+	ctx.ProxyRequest.URL.Scheme = targetURL.Scheme
+	ctx.ProxyRequest.URL.Host = targetURL.Host
 
-	s.processProxyHeader(r)
+	s.processProxyHeader(ctx)
 
-	r.URL.Path = s.apiRewrite(utils.ProcessURLPath(r.URL.Path), rule.Api.AddPath, rule.Api.SubPath, rule.Api.Rewrite)
+	ctx.ProxyRequest.URL.Path = s.apiRewrite(utils.ProcessURLPath(ctx.ProxyRequest.URL.Path), ctx.Rule.Api.AddPath, ctx.Rule.Api.SubPath, ctx.Rule.Api.Rewrite)
 
-	for _, h := range rule.Api.HeaderSet {
-		r.Header.Set(h.Header, h.Value)
+	for _, h := range ctx.Rule.Api.HeaderSet {
+		ctx.ProxyRequest.Header.Set(h.Header, h.Value)
 	}
 
-	for _, h := range rule.Api.HeaderAdd {
-		r.Header.Add(h.Header, h.Value)
+	for _, h := range ctx.Rule.Api.HeaderAdd {
+		ctx.ProxyRequest.Header.Add(h.Header, h.Value)
 	}
 
-	for _, h := range rule.Api.HeaderDel {
-		r.Header.Del(h.Header)
+	for _, h := range ctx.Rule.Api.HeaderDel {
+		ctx.ProxyRequest.Header.Del(h.Header)
 	}
 
-	query := r.URL.Query()
+	query := ctx.ProxyRequest.URL.Query()
 
-	for _, q := range rule.Api.QuerySet {
+	for _, q := range ctx.Rule.Api.QuerySet {
 		query.Set(q.Query, q.Value)
 	}
 
-	for _, q := range rule.Api.QueryAdd {
+	for _, q := range ctx.Rule.Api.QueryAdd {
 		query.Add(q.Query, q.Value)
 	}
 
-	for _, q := range rule.Api.QueryDel {
+	for _, q := range ctx.Rule.Api.QueryDel {
 		query.Del(q.Query)
 	}
 
-	r.URL.RawQuery = query.Encode()
+	ctx.ProxyRequest.URL.RawQuery = query.Encode()
 
-	s.writeViaHeader(rule, w, r)
-	proxy.ServeHTTP(w, r) // 反向代理
+	s.writeViaHeader(ctx)
+
+	req, err := ctx.ProxyWriteToHttpRRequest()
+	if err != nil {
+		s.abortServerError(ctx)
+		return
+	}
+
+	proxy.ServeHTTP(ctx.Writer, req) // 反向代理
 }
 
 func (s *HuanProxyServer) apiRewrite(srcpath string, prefix string, suffix string, rewrite *rewritecompile.RewriteCompileConfig) string {
@@ -75,12 +80,12 @@ func (s *HuanProxyServer) apiRewrite(srcpath string, prefix string, suffix strin
 	return srcpath
 }
 
-func (s *HuanProxyServer) processProxyHeader(r *http.Request) {
-	if r.RemoteAddr == "" {
+func (s *HuanProxyServer) processProxyHeader(ctx *Context) {
+	if ctx.Request.RemoteAddr() == "" {
 		return
 	}
 
-	remoteIPStr, _, err := net.SplitHostPort(r.RemoteAddr)
+	remoteIPStr, _, err := net.SplitHostPort(ctx.Request.RemoteAddr())
 	if err != nil {
 		return
 	}
@@ -90,24 +95,25 @@ func (s *HuanProxyServer) processProxyHeader(r *http.Request) {
 	var ProxyList, ForwardedList []string
 	var host, proto string
 
-	if r.Header.Get("Forwarded") != "" {
-		ProxyList, ForwardedList, host, proto = s.getProxyListForwarder(remoteIP, r)
-	} else if r.Header.Get("X-Forwarded-For") != "" {
-		ProxyList, ForwardedList, host, proto = s.getProxyListFromXForwardedFor(remoteIP, r)
+	if ctx.Request.Header().Get("Forwarded") != "" {
+		ProxyList, ForwardedList, host, proto = s.getProxyListForwarder(remoteIP, ctx.Request)
+	} else if ctx.Request.Header().Get("X-Forwarded-For") != "" {
+		ProxyList, ForwardedList, host, proto = s.getProxyListFromXForwardedFor(remoteIP, ctx.Request)
 	} else {
-		host = r.Header.Get("X-Forwarded-Host")
-		proto = r.Header.Get("X-Forwarded-Proto")
+		host = ctx.Request.Header().Get("X-Forwarded-Host")
+		proto = ctx.Request.Header().Get("X-Forwarded-Proto")
 
 		if host == "" {
-			host = r.Host
+			host = ctx.Request.Host()
 		}
 
 		host, _ = utils.SplitHostPort(host) // 去除host中的端口号
 
-		if proto == "" {
-			proto = "http"
-			if r.TLS != nil {
+		if proto == "http" || proto == "https" {
+			if ctx.Request.IsTLS() {
 				proto = "https"
+			} else {
+				proto = "http"
 			}
 		}
 
@@ -118,20 +124,20 @@ func (s *HuanProxyServer) processProxyHeader(r *http.Request) {
 			fmt.Sprintf("proto=%s", proto))
 	}
 
-	r.Header.Set("Forwarded", strings.Join(ForwardedList, ", "))
-	r.Header.Set("X-Forwarded-For", strings.Join(ProxyList, ", "))
-	r.Header.Set("X-Forwarded-Host", host)
-	r.Header.Set("X-Forwarded-Proto", proto)
+	ctx.ProxyRequest.Header.Set("Forwarded", strings.Join(ForwardedList, ", "))
+	ctx.ProxyRequest.Header.Set("X-Forwarded-For", strings.Join(ProxyList, ", "))
+	ctx.ProxyRequest.Header.Set("X-Forwarded-Host", host)
+	ctx.ProxyRequest.Header.Set("X-Forwarded-Proto", proto)
 }
 
-func (s *HuanProxyServer) getProxyListForwarder(remoteIP net.IP, r *http.Request) ([]string, []string, string, string) {
-	ForwardedList := strings.Split(r.Header.Get("Forwarded"), ",")
+func (s *HuanProxyServer) getProxyListForwarder(remoteIP net.IP, r *ReadOnlyRequest) ([]string, []string, string, string) {
+	ForwardedList := strings.Split(r.Header().Get("Forwarded"), ",")
 	ProxyList := make([]string, 0, len(ForwardedList)+1)
 	NewForwardedList := make([]string, 0, len(ForwardedList)+1)
 
-	host, _ := utils.SplitHostPort(r.Host) // 去除host中的端口号
+	host, _ := utils.SplitHostPort(r.Host()) // 去除host中的端口号
 	proto := "http"
-	if r.TLS != nil {
+	if r.IsTLS() {
 		proto = "https"
 	}
 
@@ -168,8 +174,8 @@ func (s *HuanProxyServer) getProxyListForwarder(remoteIP net.IP, r *http.Request
 	return ProxyList, NewForwardedList, host, proto
 }
 
-func (s *HuanProxyServer) getProxyListFromXForwardedFor(remoteIP net.IP, r *http.Request) ([]string, []string, string, string) {
-	XFroWardedForList := strings.Split(r.Header.Get("X-Forwarded-For"), ",")
+func (s *HuanProxyServer) getProxyListFromXForwardedFor(remoteIP net.IP, r *ReadOnlyRequest) ([]string, []string, string, string) {
+	XFroWardedForList := strings.Split(r.Header().Get("X-Forwarded-For"), ",")
 	ProxyList := make([]string, 0, len(XFroWardedForList)+1)
 	NewForwardedList := make([]string, 0, len(XFroWardedForList)+1)
 
@@ -180,19 +186,20 @@ func (s *HuanProxyServer) getProxyListFromXForwardedFor(remoteIP net.IP, r *http
 		}
 	}
 
-	host := r.Header.Get("X-Forwarded-Host")
-	proto := r.Header.Get("X-Forwarded-Proto")
+	host := r.Header().Get("X-Forwarded-Host")
+	proto := r.Header().Get("X-Forwarded-Proto")
 
 	if host == "" {
-		host = r.Host
+		host = r.Host()
 	}
 
 	host, _ = utils.SplitHostPort(host) // 去除host中的端口号
 
-	if proto == "" {
-		proto = "http"
-		if r.TLS != nil {
+	if proto == "http" || proto == "https" {
+		if r.IsTLS() {
 			proto = "https"
+		} else {
+			proto = "http"
 		}
 	}
 

+ 64 - 7
src/server/context.go

@@ -1,21 +1,78 @@
 package server
 
 import (
+	"fmt"
 	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"net/http"
 )
 
 type Context struct {
-	Abort   bool
-	Writer  writer
-	Request *http.Request
-	Rule    *rulescompile.RuleCompileConfig
+	Abort        bool
+	Writer       http.ResponseWriter
+	Request      *ReadOnlyRequest
+	ProxyRequest *ProxyRequest
+	Rule         *rulescompile.RuleCompileConfig
 }
 
 func NewContext(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) *Context {
+	var proxyRequest *ProxyRequest = nil
+	if rule.Type == rulescompile.ProxyTypeAPI {
+		proxyRequest = NewRequest(r)
+	}
+
 	return &Context{
-		Writer:  NewWriter(w),
-		Request: r,
-		Rule:    rule,
+		Writer:       NewResponseWriter(w),
+		Request:      NewReadOnlyRequest(r),
+		ProxyRequest: proxyRequest,
+		Rule:         rule,
+	}
+}
+
+func (ctx *Context) ProxyWriteToHttpRRequest() (*http.Request, error) {
+	if ctx.ProxyRequest == nil {
+		return nil, fmt.Errorf("proxy request is nil")
+	}
+
+	req, err := ctx.ProxyRequest.WriteToHttpRRequest()
+	if err != nil {
+		return nil, err
+	}
+
+	return req, nil
+}
+
+func (ctx *Context) WriteToResponse() error {
+	w, ok := ctx.Writer.(*ResponseWriter)
+	if !ok {
+		return nil
+	}
+	err := w.WriteToResponse()
+	if err != nil {
+		return err
 	}
+	return nil
+}
+
+func (ctx *Context) MustWriteToResponse() {
+	if w, ok := ctx.Writer.(*ResponseWriter); ok {
+		w.MustWriteToResponse()
+	}
+}
+
+func (ctx *Context) Reset() error {
+	ctx.Abort = false
+
+	if w, ok := ctx.Writer.(*ResponseWriter); ok {
+		_ = w.Reset()
+	}
+
+	if ctx.ProxyRequest != nil {
+		_ = ctx.ProxyRequest.Reset()
+	}
+
+	return nil
+}
+
+func (ctx *Context) Redirect(target string, code int) {
+	http.Redirect(ctx.Writer, ctx.Request.req, target, code)
 }

+ 2 - 2
src/server/cors.go

@@ -8,7 +8,7 @@ import (
 
 func (s *HuanProxyServer) cors(corsRule *corscompile.CorsCompileConfig, ctx *Context) bool {
 	if corsRule.Ignore {
-		if ctx.Request.Method == http.MethodOptions {
+		if ctx.Request.Method() == http.MethodOptions {
 			s.abortMethodNotAllowed(ctx)
 			return false
 		} else {
@@ -16,7 +16,7 @@ func (s *HuanProxyServer) cors(corsRule *corscompile.CorsCompileConfig, ctx *Con
 		}
 	}
 
-	origin := ctx.Request.Header.Get("Origin")
+	origin := ctx.Request.Header().Get("Origin")
 	if origin == "" {
 		s.abortForbidden(ctx)
 		return false

+ 7 - 0
src/server/default.go

@@ -0,0 +1,7 @@
+package server
+
+import "net/http"
+
+func (s *HuanProxyServer) defaultResponse(w http.ResponseWriter) {
+	w.WriteHeader(http.StatusNotFound)
+}

+ 28 - 28
src/server/dir.go

@@ -15,83 +15,83 @@ import (
 const IndexMaxDeep = 5
 const DefaultIgnoreFileMap = 20
 
-func (s *HuanProxyServer) dirServer(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) {
-	if !s.cors(rule.Dir.Cors, w, r) {
+func (s *HuanProxyServer) dirServer(ctx *Context) {
+	if !s.cors(ctx.Rule.Dir.Cors, ctx) {
 		return
 	}
 
-	if r.Method != http.MethodGet {
-		s.abortMethodNotAllowed(w)
+	if ctx.Request.Method() != http.MethodGet {
+		s.abortMethodNotAllowed(ctx)
 		return
 	}
 
-	dirBasePath := rule.Dir.BasePath // 根部目录
-	fileAccess := ""                 // 访问目录
-	filePath := ""                   // 根部目录+访问目录=实际目录
+	dirBasePath := ctx.Rule.Dir.BasePath // 根部目录
+	fileAccess := ""                     // 访问目录
+	filePath := ""                       // 根部目录+访问目录=实际目录
 
-	url := utils.ProcessURLPath(r.URL.Path)
-	if rule.MatchType == matchcompile.RegexMatch {
-		fileAccess = s.dirRewrite("", rule.Dir.AddPath, rule.Dir.SubPath, rule.Dir.Rewrite)
+	url := utils.ProcessURLPath(ctx.Request.URL().Path)
+	if ctx.Rule.MatchType == matchcompile.RegexMatch {
+		fileAccess = s.dirRewrite("", ctx.Rule.Dir.AddPath, ctx.Rule.Dir.SubPath, ctx.Rule.Dir.Rewrite)
 		filePath = path.Join(dirBasePath, fileAccess)
 	} else {
-		if url == rule.MatchPath {
-			fileAccess = s.dirRewrite("", rule.Dir.AddPath, rule.Dir.SubPath, rule.Dir.Rewrite)
+		if url == ctx.Rule.MatchPath {
+			fileAccess = s.dirRewrite("", ctx.Rule.Dir.AddPath, ctx.Rule.Dir.SubPath, ctx.Rule.Dir.Rewrite)
 			filePath = path.Join(dirBasePath, fileAccess)
-		} else if strings.HasPrefix(url, rule.MatchPath+"/") {
-			fileAccess = s.dirRewrite(url[len(rule.MatchPath+"/"):], rule.Dir.AddPath, rule.Dir.SubPath, rule.Dir.Rewrite)
+		} else if strings.HasPrefix(url, ctx.Rule.MatchPath+"/") {
+			fileAccess = s.dirRewrite(url[len(ctx.Rule.MatchPath+"/"):], ctx.Rule.Dir.AddPath, ctx.Rule.Dir.SubPath, ctx.Rule.Dir.Rewrite)
 			filePath = path.Join(dirBasePath, fileAccess)
 		} else {
-			s.abortNotFound(w)
+			s.abortNotFound(ctx)
 			return
 		}
 	}
 
 	if filePath == "" {
-		s.abortNotFound(w) // 正常清空不会走到这个流程
+		s.abortNotFound(ctx) // 正常清空不会走到这个流程
 		return
 	}
 
 	if utils.IsFile(filePath) {
 		// 判断这个文件是否被ignore,因为ignore是从dirBasePath写起,也可以是完整路径,因此filePath和fileAccess都要判断
-		for _, ignore := range rule.Dir.IgnoreFile {
+		for _, ignore := range ctx.Rule.Dir.IgnoreFile {
 			if ignore.CheckName(fileAccess) || ignore.CheckName(filePath) {
-				s.abortNotFound(w)
+				s.abortNotFound(ctx)
 				return
 			}
 		}
 	} else {
-		filePath = s.getIndexFile(rule, filePath)
+		filePath = s.getIndexFile(ctx.Rule, filePath)
 		if filePath == "" || !utils.IsFile(filePath) {
-			s.abortNotFound(w)
+			s.abortNotFound(ctx)
 			return
 		}
 	}
 
 	if !utils.CheckIfSubPath(dirBasePath, filePath) {
-		s.abortForbidden(w)
+		s.abortForbidden(ctx)
 		return
 	}
 
 	file, err := os.ReadFile(filePath)
 	if err != nil {
-		s.abortNotFound(w)
+		s.abortNotFound(ctx)
 		return
 	}
 
 	mimeType := mimetype.Detect(file)
-	accept := r.Header.Get("Accept")
+	accept := ctx.Request.Header().Get("Accept")
 	if !utils.AcceptMimeType(accept, mimeType.String()) {
-		s.abortNotAcceptable(w)
+		s.abortNotAcceptable(ctx)
 		return
 	}
 
-	_, err = w.Write(file)
+	_, err = ctx.Writer.Write(file)
 	if err != nil {
-		s.abortServerError(w)
+		s.abortServerError(ctx)
 		return
 	}
-	w.Header().Set("Content-Type", mimeType.String())
-	s.statusOK(w)
+	ctx.Writer.Header().Set("Content-Type", mimeType.String())
+	s.statusOK(ctx)
 }
 
 func (s *HuanProxyServer) dirRewrite(srcpath string, prefix string, suffix string, rewrite *rewritecompile.RewriteCompileConfig) string {

+ 2 - 2
src/server/file.go

@@ -12,7 +12,7 @@ func (s *HuanProxyServer) fileServer(ctx *Context) {
 		return
 	}
 
-	if r.Method != http.MethodGet {
+	if ctx.Request.Method() != http.MethodGet {
 		s.abortMethodNotAllowed(ctx)
 		return
 	}
@@ -24,7 +24,7 @@ func (s *HuanProxyServer) fileServer(ctx *Context) {
 	}
 
 	mimeType := mimetype.Detect(file)
-	accept := r.Header.Get("Accept")
+	accept := ctx.Request.Header().Get("Accept")
 	if !utils.AcceptMimeType(accept, mimeType.String()) {
 		s.abortNotAcceptable(ctx)
 		return

+ 49 - 0
src/server/hptag.go

@@ -0,0 +1,49 @@
+package server
+
+import (
+	"fmt"
+	resource "github.com/SongZihuan/huan-proxy"
+	"github.com/SongZihuan/huan-proxy/src/config/rulescompile/actioncompile/apicompile"
+	"github.com/SongZihuan/huan-proxy/src/utils"
+	"net/http"
+	"strings"
+)
+
+const XHuanProxyHeaer = apicompile.XHuanProxyHeaer
+const ViaHeader = apicompile.ViaHeader
+
+func (s *HuanProxyServer) writeHuanProxyHeader(ctx *Context) {
+	version := strings.TrimSpace(utils.StringToOnlyPrint(resource.Version))
+	ctx.Writer.Header().Set(XHuanProxyHeaer, version)
+	if ctx.ProxyRequest != nil {
+		ctx.ProxyRequest.Header.Set(XHuanProxyHeaer, version)
+	}
+}
+
+func (s *HuanProxyServer) writeViaHeader(ctx *Context) {
+	info := fmt.Sprintf("%s %s", ctx.Request.MustProto(), ctx.Rule.Api.Via)
+
+	reqHeader := ctx.Request.Header().Get(ViaHeader)
+	if reqHeader == "" {
+		reqHeader = info
+	} else {
+		reqHeader = fmt.Sprintf("%s, %s", reqHeader, info)
+	}
+	ctx.Request.Header().Set(ViaHeader, reqHeader)
+
+	respHeader := ctx.Writer.Header().Get(ViaHeader)
+	if respHeader == "" {
+		respHeader = info
+	} else if !strings.Contains(respHeader, info) {
+		respHeader = fmt.Sprintf("%s, %s", respHeader, info)
+	}
+	ctx.Writer.Header().Set(ViaHeader, respHeader)
+}
+
+func (s *HuanProxyServer) statusOK(ctx *Context) {
+	ctx.Writer.WriteHeader(http.StatusOK)
+}
+
+func (s *HuanProxyServer) statusRedirect(ctx *Context, url string, code int) {
+	ctx.Redirect(url, code)
+}

+ 6 - 2
src/server/loggerserver.go

@@ -137,7 +137,11 @@ func (s *HuanProxyServer) LoggerServerHTTP(_w http.ResponseWriter, r *http.Reque
 	path := r.URL.Path
 	raw := r.URL.RawQuery
 
-	w := NewWriter(_w)
+	w, ok := NewResponseWriter(_w).(*ResponseWriter)
+	if !ok {
+		_w.WriteHeader(http.StatusInternalServerError)
+		return
+	}
 
 	// Process request
 	next(w, r)
@@ -154,7 +158,7 @@ func (s *HuanProxyServer) LoggerServerHTTP(_w http.ResponseWriter, r *http.Reque
 
 	param.RemoteAddr = r.RemoteAddr
 	param.Method = r.Method
-	param.StatusCode = w.status
+	param.StatusCode = w.Status()
 	param.BodySize = w.Size()
 
 	if raw != "" {

+ 19 - 14
src/server/serverhttp.go → src/server/normalserver.go

@@ -1,17 +1,17 @@
 package server
 
 import (
+	"github.com/SongZihuan/huan-proxy/src/config"
 	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"net/http"
 )
 
 func (s *HuanProxyServer) NormalServeHTTP(w http.ResponseWriter, r *http.Request) {
-	s.writeHuanProxyHeader(w, r)
-
 	func() {
+	RuleCycle:
 		for _, rule := range s.GetRulesList() {
 			if !s.matchURL(rule, r) {
-				continue
+				continue RuleCycle
 			}
 
 			ctx := NewContext(rule, w, r)
@@ -20,24 +20,29 @@ func (s *HuanProxyServer) NormalServeHTTP(w http.ResponseWriter, r *http.Request
 				return
 			}
 
+			s.writeHuanProxyHeader(ctx)
+
 			if rule.Type == rulescompile.ProxyTypeFile {
 				s.fileServer(ctx)
-				return
 			} else if rule.Type == rulescompile.ProxyTypeDir {
-				s.dirServer(rule, w, r)
-				return
+				s.dirServer(ctx)
 			} else if rule.Type == rulescompile.ProxyTypeAPI {
-				s.apiServer(rule, w, r)
-				return
+				s.apiServer(ctx)
 			} else if rule.Type == rulescompile.ProxyTypeRedirect {
-				s.redirectServer(rule, w, r)
-				return
+				s.redirectServer(ctx)
 			} else {
-				s.abortServerError(w)
-				return
+				s.abortServerError(ctx)
 			}
-		}
 
-		s.abortNotFound(w)
+			if config.GetConfig().NotAbort.IsEnable(false) {
+				_ = ctx.Reset()
+				continue RuleCycle
+			}
+
+			ctx.MustWriteToResponse()
+			return
+
+		}
+		s.defaultResponse(w)
 	}()
 }

+ 6 - 6
src/server/proxytrust.go

@@ -10,20 +10,20 @@ func (s *HuanProxyServer) checkProxyTrust(ctx *Context) bool {
 		return true
 	}
 
-	if ctx.Request.RemoteAddr == "" {
-		s.abortForbidden(w)
+	if ctx.Request.RemoteAddr() == "" {
+		s.abortForbidden(ctx)
 		return false
 	}
 
-	remoteIPStr, _, err := net.SplitHostPort(ctx.Request.RemoteAddr)
+	remoteIPStr, _, err := net.SplitHostPort(ctx.Request.RemoteAddr())
 	if err != nil {
-		s.abortForbidden(w)
+		s.abortForbidden(ctx)
 		return false
 	}
 
 	remoteIP := net.ParseIP(remoteIPStr)
 	if remoteIP == nil {
-		s.abortForbidden(w)
+		s.abortForbidden(ctx)
 		return false
 	}
 
@@ -45,6 +45,6 @@ func (s *HuanProxyServer) checkProxyTrust(ctx *Context) bool {
 		}
 	}
 
-	s.abortForbidden(w)
+	s.abortForbidden(ctx)
 	return false
 }

+ 62 - 0
src/server/readonlyrequests.go

@@ -0,0 +1,62 @@
+package server
+
+import (
+	"github.com/SongZihuan/huan-proxy/src/utils"
+	"net/http"
+	"net/url"
+)
+
+type ReadOnlyRequest struct {
+	req    *http.Request
+	url    *url.URL
+	header http.Header
+}
+
+func NewReadOnlyRequest(req *http.Request) *ReadOnlyRequest {
+	return &ReadOnlyRequest{
+		req:    req,
+		url:    utils.URLClone(req.URL),
+		header: req.Header.Clone(),
+	}
+}
+
+func (r *ReadOnlyRequest) Host() string {
+	return r.req.Host
+}
+
+func (r *ReadOnlyRequest) Method() string {
+	return r.req.Method
+}
+
+func (r *ReadOnlyRequest) RemoteAddr() string {
+	return r.req.Host
+}
+
+func (r *ReadOnlyRequest) Proto() string {
+	return r.req.Proto
+}
+
+func (r *ReadOnlyRequest) MustProto() string {
+	proto := r.req.Proto
+	if proto == "" {
+		if r.IsTLS() {
+			return "https"
+		} else {
+			return "http"
+		}
+	} else {
+		return proto
+	}
+}
+
+func (r *ReadOnlyRequest) URL() *url.URL {
+	return utils.URLClone(r.req.URL)
+}
+
+func (r *ReadOnlyRequest) Header() http.Header {
+	return r.req.Header.Clone()
+}
+
+func (r *ReadOnlyRequest) IsTLS() bool {
+	return r.req.TLS != nil
+}

+ 4 - 6
src/server/redirect.go

@@ -2,22 +2,20 @@ package server
 
 import (
 	"fmt"
-	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
 	"github.com/SongZihuan/huan-proxy/src/config/rulescompile/actioncompile/rewritecompile"
-	"net/http"
 	"net/url"
 )
 
-func (s *HuanProxyServer) redirectServer(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) {
-	target := s.redirectRewrite(rule.Redirect.Address, rule.Redirect.Rewrite)
+func (s *HuanProxyServer) redirectServer(ctx *Context) {
+	target := s.redirectRewrite(ctx.Rule.Redirect.Address, ctx.Rule.Redirect.Rewrite)
 
 	if _, err := url.Parse(target); err != nil {
-		s.abortServerError(w)
+		s.abortServerError(ctx)
 		return
 	}
 
 	fmt.Printf("target: %s\n", target)
-	s.statusRedirect(w, r, target, rule.Redirect.Code)
+	s.statusRedirect(ctx, target, ctx.Rule.Redirect.Code)
 }
 
 func (s *HuanProxyServer) redirectRewrite(address string, rewrite *rewritecompile.RewriteCompileConfig) string {

+ 99 - 0
src/server/request.go

@@ -0,0 +1,99 @@
+package server
+
+import (
+	"github.com/SongZihuan/huan-proxy/src/utils"
+	"net/http"
+	"net/url"
+)
+
+type ProxyRequest struct {
+	req *http.Request
+
+	Host       string
+	Proto      string
+	IsTLS      bool
+	Method     string
+	RemoteAddr string
+	URL        *url.URL
+	Header     http.Header
+
+	_host       string
+	_proto      string
+	_method     string
+	_remoteAddr string
+	_url        *url.URL
+	_header     http.Header
+
+	written bool
+}
+
+func NewRequest(req *http.Request) *ProxyRequest {
+	return &ProxyRequest{
+		req: req,
+
+		Host:       req.Host,
+		Proto:      req.Proto,
+		IsTLS:      req.TLS != nil,
+		Method:     req.Method,
+		RemoteAddr: req.RemoteAddr,
+		Header:     req.Header.Clone(),
+		URL:        utils.URLClone(req.URL),
+
+		_host:       req.Host,
+		_proto:      req.Proto,
+		_method:     req.Method,
+		_remoteAddr: req.RemoteAddr,
+		_url:        req.URL,
+		_header:     req.Header,
+
+		written: false,
+	}
+}
+
+func (r *ProxyRequest) ResetHttpRequest() error {
+	r.req.Host = r._host
+	r.req.Proto = r._proto
+	r.req.Method = r._method
+	r.req.RemoteAddr = r._remoteAddr
+	r.req.Header = r._header
+	r.req.URL = r._url
+	return nil
+}
+
+func (r *ProxyRequest) Reset() error {
+	err := r.ResetHttpRequest()
+	if err != nil {
+		return err
+	}
+
+	r.Host = r.req.Host
+	r.Proto = r.req.Proto
+	r.Method = r.req.Method
+	r.RemoteAddr = r.req.RemoteAddr
+	r.Header = r.req.Header.Clone()
+	r.URL = utils.URLClone(r.req.URL)
+	r.written = false
+	return nil
+}
+
+func (r *ProxyRequest) WriteToHttpRRequest() (req *http.Request, err error) {
+	if r.written {
+		return r.req, nil
+	}
+
+	defer func() {
+		if err != nil {
+			_ = r.ResetHttpRequest() // 复原所有操作
+			r.written = false
+		}
+	}()
+
+	r.req.Host = r.Host
+	r.req.Proto = r.Proto
+	r.req.Method = r.Method
+	r.req.RemoteAddr = r.RemoteAddr
+	r.req.URL = utils.URLClone(r.URL)
+	r.req.Header = r.Header.Clone()
+	r.written = true
+	return r.req, nil
+}

+ 0 - 48
src/server/respose.go

@@ -1,48 +0,0 @@
-package server
-
-import (
-	"fmt"
-	resource "github.com/SongZihuan/huan-proxy"
-	"github.com/SongZihuan/huan-proxy/src/config/rulescompile"
-	"github.com/SongZihuan/huan-proxy/src/config/rulescompile/actioncompile/apicompile"
-	"github.com/SongZihuan/huan-proxy/src/utils"
-	"net/http"
-	"strings"
-)
-
-const XHuanProxyHeaer = apicompile.XHuanProxyHeaer
-const ViaHeader = apicompile.ViaHeader
-
-func (s *HuanProxyServer) writeHuanProxyHeader(w http.ResponseWriter, r *http.Request) {
-	version := strings.TrimSpace(utils.StringToOnlyPrint(resource.Version))
-	r.Header.Set(XHuanProxyHeaer, version)
-	w.Header().Set(XHuanProxyHeaer, version)
-}
-
-func (s *HuanProxyServer) writeViaHeader(rule *rulescompile.RuleCompileConfig, w http.ResponseWriter, r *http.Request) {
-	info := fmt.Sprintf("%s %s", r.Proto, rule.Api.Via)
-
-	reqHeader := r.Header.Get(ViaHeader)
-	if reqHeader == "" {
-		reqHeader = info
-	} else {
-		reqHeader = fmt.Sprintf("%s, %s", reqHeader, info)
-	}
-	r.Header.Set(ViaHeader, reqHeader)
-
-	respHeader := w.Header().Get(ViaHeader)
-	if respHeader == "" {
-		respHeader = info
-	} else if !strings.Contains(respHeader, info) {
-		respHeader = fmt.Sprintf("%s, %s", respHeader, info)
-	}
-	w.Header().Set(ViaHeader, respHeader)
-}
-
-func (s *HuanProxyServer) statusOK(w http.ResponseWriter) {
-	w.WriteHeader(http.StatusOK)
-}
-
-func (s *HuanProxyServer) statusRedirect(w http.ResponseWriter, r *http.Request, url string, code int) {
-	http.Redirect(w, r, url, code)
-}

+ 50 - 17
src/server/writer.go

@@ -2,56 +2,80 @@ package server
 
 import (
 	"bytes"
+	"fmt"
 	"net/http"
 )
 
-type writer http.ResponseWriter
+var ErrHasWriter = fmt.Errorf("ResponseWriter has been written")
 
 type ResponseWriter struct {
-	writer
-	status  int
-	buffer  bytes.Buffer
+	writer http.ResponseWriter
+
+	status int
+	buffer bytes.Buffer
+	header http.Header
+
 	written bool
-	header  http.Header
 }
 
-func NewWriter(w writer) *ResponseWriter {
-	res := &ResponseWriter{
+func NewResponseWriter(w http.ResponseWriter) http.ResponseWriter {
+	if _, ok := w.(*ResponseWriter); ok {
+		return w
+	}
+
+	return &ResponseWriter{
 		writer: w,
+
 		status: 0,
-		header: make(http.Header, 10),
-	}
+		header: w.Header().Clone(),
 
-	for n, h := range w.Header() {
-		nh := make([]string, 0, len(h))
-		copy(nh, h)
-		res.header[n] = nh
+		written: false,
 	}
-
-	return res
 }
 
 func (r *ResponseWriter) Size() int64 {
 	return int64(r.buffer.Len())
 }
 
+func (r *ResponseWriter) Status() int {
+	return r.status
+}
+
 func (r *ResponseWriter) Write(p []byte) (int, error) {
+	if r.written {
+		return 0, ErrHasWriter
+	}
+
 	return r.buffer.Write(p)
 }
 
 func (r *ResponseWriter) WriteHeader(statusCode int) {
+	if r.written {
+		return
+	}
+
 	r.status = statusCode
 }
 
 func (r *ResponseWriter) Header() http.Header {
+	if r.written {
+		return nil
+	}
+
 	return r.header
 }
 
-func (r *ResponseWriter) Reset() {
+func (r *ResponseWriter) Reset() error {
+	if r.written {
+		return ErrHasWriter
+	}
+
 	r.status = 0
-	r.header = make(http.Header, 10)
 	r.buffer.Reset()
+	r.header = r.writer.Header()
 	r.written = false
+
+	return nil
 }
 
 func (r *ResponseWriter) WriteToResponse() error {
@@ -87,3 +111,12 @@ func (r *ResponseWriter) WriteToResponse() error {
 	r.written = true
 	return nil
 }
+
+func (r *ResponseWriter) MustWriteToResponse() {
+	err := r.WriteToResponse()
+	if err == nil {
+		return
+	}
+
+	r.writer.WriteHeader(http.StatusInternalServerError)
+}

+ 13 - 1
src/utils/url.go

@@ -1,6 +1,9 @@
 package utils
 
-import "strings"
+import (
+	"net/url"
+	"strings"
+)
 
 /*
 设计理念:
@@ -59,3 +62,12 @@ func validOptionalPort(port string) bool {
 	}
 	return true
 }
+
+func URLClone(old *url.URL) *url.URL {
+	reqURL := *old
+	if old.User != nil {
+		reqUser := *old.User
+		reqURL.User = &reqUser
+	}
+	return &reqURL
+}