|
@@ -1,30 +1,89 @@
|
|
|
package server
|
|
|
|
|
|
-import "net/http"
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "net/http"
|
|
|
+)
|
|
|
+
|
|
|
+type writer http.ResponseWriter
|
|
|
|
|
|
type ResponseWriter struct {
|
|
|
- http.ResponseWriter
|
|
|
- Status int
|
|
|
- Size int64
|
|
|
+ writer
|
|
|
+ status int
|
|
|
+ buffer bytes.Buffer
|
|
|
+ written bool
|
|
|
+ header http.Header
|
|
|
}
|
|
|
|
|
|
-func NewWriter(w http.ResponseWriter) *ResponseWriter {
|
|
|
- return &ResponseWriter{
|
|
|
- ResponseWriter: w,
|
|
|
- Status: 0,
|
|
|
+func NewWriter(w writer) *ResponseWriter {
|
|
|
+ res := &ResponseWriter{
|
|
|
+ writer: w,
|
|
|
+ status: 0,
|
|
|
+ header: make(http.Header, 10),
|
|
|
+ }
|
|
|
+
|
|
|
+ for n, h := range w.Header() {
|
|
|
+ nh := make([]string, 0, len(h))
|
|
|
+ copy(nh, h)
|
|
|
+ res.header[n] = nh
|
|
|
}
|
|
|
+
|
|
|
+ return res
|
|
|
+}
|
|
|
+
|
|
|
+func (r *ResponseWriter) Size() int64 {
|
|
|
+ return int64(r.buffer.Len())
|
|
|
}
|
|
|
|
|
|
func (r *ResponseWriter) Write(p []byte) (int, error) {
|
|
|
- n, err := r.ResponseWriter.Write(p)
|
|
|
- if err != nil {
|
|
|
- return n, err
|
|
|
- }
|
|
|
- r.Size += int64(n)
|
|
|
- return n, nil
|
|
|
+ return r.buffer.Write(p)
|
|
|
}
|
|
|
|
|
|
func (r *ResponseWriter) WriteHeader(statusCode int) {
|
|
|
- r.Status = statusCode
|
|
|
- r.ResponseWriter.WriteHeader(statusCode)
|
|
|
+ r.status = statusCode
|
|
|
+}
|
|
|
+
|
|
|
+func (r *ResponseWriter) Header() http.Header {
|
|
|
+ return r.header
|
|
|
+}
|
|
|
+
|
|
|
+func (r *ResponseWriter) Reset() {
|
|
|
+ r.status = 0
|
|
|
+ r.header = make(http.Header, 10)
|
|
|
+ r.buffer.Reset()
|
|
|
+ r.written = false
|
|
|
+}
|
|
|
+
|
|
|
+func (r *ResponseWriter) WriteToResponse() error {
|
|
|
+ if r.written {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err := r.writer.Write(r.buffer.Bytes())
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ r.writer.WriteHeader(r.status)
|
|
|
+
|
|
|
+ writerHeader := r.writer.Header()
|
|
|
+ for n, h := range r.header {
|
|
|
+ nh := make([]string, 0, len(h))
|
|
|
+ copy(nh, h)
|
|
|
+ writerHeader[n] = nh
|
|
|
+ }
|
|
|
+
|
|
|
+ delHeader := make([]string, 0, 10)
|
|
|
+ for n, _ := range writerHeader {
|
|
|
+ if _, ok := r.header[n]; !ok {
|
|
|
+ delHeader = append(delHeader, n)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, n := range delHeader {
|
|
|
+ delete(writerHeader, n)
|
|
|
+ }
|
|
|
+
|
|
|
+ r.written = true
|
|
|
+ return nil
|
|
|
}
|