// Source file: contentsecurityhandler_test.go (9.6 KB)

  1. package handler
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "log"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "os"
  13. "strconv"
  14. "strings"
  15. "testing"
  16. "time"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/zeromicro/go-zero/core/codec"
  19. "github.com/zeromicro/go-zero/rest/httpx"
  20. )
  21. const timeDiff = time.Hour * 2 * 24
  22. var (
  23. fingerprint = "12345"
  24. pubKey = []byte(`-----BEGIN PUBLIC KEY-----
  25. MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
  26. eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
  27. miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
  28. my47YlhspwszKdRP+wIDAQAB
  29. -----END PUBLIC KEY-----`)
  30. priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
  31. MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
  32. 1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
  33. r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
  34. AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
  35. Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
  36. J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
  37. Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
  38. cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
  39. ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
  40. 3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
  41. MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
  42. Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
  43. moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
  44. -----END RSA PRIVATE KEY-----`)
  45. key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
  46. )
// requestSettings describes one request for buildRequest to assemble.
type requestSettings struct {
	// method is the HTTP method; it is also part of the signed content.
	method string
	// url is the request target; its path and query are signed unless
	// requestUri is set.
	url string
	// body is the payload. buildRequest only reads it when crypt is true.
	body io.Reader
	// strict is copied from the test case; buildRequest does not read it.
	strict bool
	// crypt selects the header's `type` value ("1" when true, "0" otherwise)
	// and makes buildRequest encrypt and base64-encode the body.
	crypt bool
	// requestUri, when non-empty, supplies the path and query that get signed
	// and is sent in the X-Request-Uri header.
	requestUri string
	// timestamp is written in decimal into both the signed content and the
	// encrypted header secret. The tests fill it from time.Time.Unix.
	timestamp int64
	// fingerprint is sent as the `key=` field of the security header.
	fingerprint string
	// missHeader, when true, leaves the security header off the request.
	missHeader bool
	// signature, when non-empty, is sent verbatim instead of being computed.
	signature string
}
// init redirects the standard library logger to io.Discard, so anything
// written through package log while this test binary runs is dropped.
func init() {
	log.SetOutput(io.Discard)
}
  62. func TestContentSecurityHandler(t *testing.T) {
  63. tests := []struct {
  64. method string
  65. url string
  66. body string
  67. strict bool
  68. crypt bool
  69. requestUri string
  70. timestamp int64
  71. fingerprint string
  72. missHeader bool
  73. signature string
  74. statusCode int
  75. }{
  76. {
  77. method: http.MethodGet,
  78. url: "http://localhost/a/b?c=d&e=f",
  79. strict: true,
  80. crypt: false,
  81. },
  82. {
  83. method: http.MethodPost,
  84. url: "http://localhost/a/b?c=d&e=f",
  85. body: "hello",
  86. strict: true,
  87. crypt: false,
  88. },
  89. {
  90. method: http.MethodGet,
  91. url: "http://localhost/a/b?c=d&e=f",
  92. strict: true,
  93. crypt: true,
  94. },
  95. {
  96. method: http.MethodPost,
  97. url: "http://localhost/a/b?c=d&e=f",
  98. body: "hello",
  99. strict: true,
  100. crypt: true,
  101. },
  102. {
  103. method: http.MethodGet,
  104. url: "http://localhost/a/b?c=d&e=f",
  105. strict: true,
  106. crypt: true,
  107. timestamp: time.Now().Add(timeDiff).Unix(),
  108. statusCode: http.StatusForbidden,
  109. },
  110. {
  111. method: http.MethodPost,
  112. url: "http://localhost/a/b?c=d&e=f",
  113. body: "hello",
  114. strict: true,
  115. crypt: true,
  116. timestamp: time.Now().Add(-timeDiff).Unix(),
  117. statusCode: http.StatusForbidden,
  118. },
  119. {
  120. method: http.MethodPost,
  121. url: "http://remotehost/",
  122. body: "hello",
  123. strict: true,
  124. crypt: true,
  125. requestUri: "http://localhost/a/b?c=d&e=f",
  126. },
  127. {
  128. method: http.MethodPost,
  129. url: "http://localhost/a/b?c=d&e=f",
  130. body: "hello",
  131. strict: false,
  132. crypt: true,
  133. fingerprint: "badone",
  134. },
  135. {
  136. method: http.MethodPost,
  137. url: "http://localhost/a/b?c=d&e=f",
  138. body: "hello",
  139. strict: true,
  140. crypt: true,
  141. timestamp: time.Now().Add(-timeDiff).Unix(),
  142. fingerprint: "badone",
  143. statusCode: http.StatusForbidden,
  144. },
  145. {
  146. method: http.MethodPost,
  147. url: "http://localhost/a/b?c=d&e=f",
  148. body: "hello",
  149. strict: true,
  150. crypt: true,
  151. missHeader: true,
  152. statusCode: http.StatusForbidden,
  153. },
  154. {
  155. method: http.MethodHead,
  156. url: "http://localhost/a/b?c=d&e=f",
  157. strict: true,
  158. crypt: false,
  159. },
  160. {
  161. method: http.MethodGet,
  162. url: "http://localhost/a/b?c=d&e=f",
  163. strict: true,
  164. crypt: false,
  165. signature: "badone",
  166. statusCode: http.StatusForbidden,
  167. },
  168. }
  169. for _, test := range tests {
  170. t.Run(test.url, func(t *testing.T) {
  171. if test.statusCode == 0 {
  172. test.statusCode = http.StatusOK
  173. }
  174. if len(test.fingerprint) == 0 {
  175. test.fingerprint = fingerprint
  176. }
  177. if test.timestamp == 0 {
  178. test.timestamp = time.Now().Unix()
  179. }
  180. func() {
  181. keyFile, err := createTempFile(priKey)
  182. defer os.Remove(keyFile)
  183. assert.Nil(t, err)
  184. decrypter, err := codec.NewRsaDecrypter(keyFile)
  185. assert.Nil(t, err)
  186. contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
  187. fingerprint: decrypter,
  188. }, time.Hour, test.strict)
  189. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  190. }))
  191. var reader io.Reader
  192. if len(test.body) > 0 {
  193. reader = strings.NewReader(test.body)
  194. }
  195. setting := requestSettings{
  196. method: test.method,
  197. url: test.url,
  198. body: reader,
  199. strict: test.strict,
  200. crypt: test.crypt,
  201. requestUri: test.requestUri,
  202. timestamp: test.timestamp,
  203. fingerprint: test.fingerprint,
  204. missHeader: test.missHeader,
  205. signature: test.signature,
  206. }
  207. req, err := buildRequest(setting)
  208. assert.Nil(t, err)
  209. resp := httptest.NewRecorder()
  210. handler.ServeHTTP(resp, req)
  211. assert.Equal(t, test.statusCode, resp.Code)
  212. }()
  213. })
  214. }
  215. }
  216. func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
  217. keyFile, err := createTempFile(priKey)
  218. defer os.Remove(keyFile)
  219. assert.Nil(t, err)
  220. decrypter, err := codec.NewRsaDecrypter(keyFile)
  221. assert.Nil(t, err)
  222. contentSecurityHandler := ContentSecurityHandler(
  223. map[string]codec.RsaDecrypter{
  224. fingerprint: decrypter,
  225. },
  226. time.Hour,
  227. true,
  228. func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  229. w.WriteHeader(http.StatusOK)
  230. })
  231. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  232. setting := requestSettings{
  233. method: http.MethodGet,
  234. url: "http://localhost/a/b?c=d&e=f",
  235. signature: "badone",
  236. }
  237. req, err := buildRequest(setting)
  238. assert.Nil(t, err)
  239. resp := httptest.NewRecorder()
  240. handler.ServeHTTP(resp, req)
  241. assert.Equal(t, http.StatusOK, resp.Code)
  242. }
  243. func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
  244. keyFile, err := createTempFile(priKey)
  245. defer os.Remove(keyFile)
  246. assert.Nil(t, err)
  247. decrypter, err := codec.NewRsaDecrypter(keyFile)
  248. assert.Nil(t, err)
  249. contentSecurityHandler := ContentSecurityHandler(
  250. map[string]codec.RsaDecrypter{
  251. fingerprint: decrypter,
  252. },
  253. time.Hour,
  254. true,
  255. func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  256. assert.Equal(t, httpx.CodeSignatureWrongTime, code)
  257. w.WriteHeader(http.StatusOK)
  258. })
  259. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  260. reader := strings.NewReader("hello")
  261. setting := requestSettings{
  262. method: http.MethodPost,
  263. url: "http://localhost/a/b?c=d&e=f",
  264. body: reader,
  265. strict: true,
  266. crypt: true,
  267. timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
  268. fingerprint: fingerprint,
  269. }
  270. req, err := buildRequest(setting)
  271. assert.Nil(t, err)
  272. resp := httptest.NewRecorder()
  273. handler.ServeHTTP(resp, req)
  274. assert.Equal(t, http.StatusOK, resp.Code)
  275. }
  276. func buildRequest(rs requestSettings) (*http.Request, error) {
  277. var bodyStr string
  278. var err error
  279. if rs.crypt && rs.body != nil {
  280. var buf bytes.Buffer
  281. io.Copy(&buf, rs.body)
  282. bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
  283. if err != nil {
  284. return nil, err
  285. }
  286. bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
  287. }
  288. r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
  289. if len(rs.signature) == 0 {
  290. sha := sha256.New()
  291. sha.Write([]byte(bodyStr))
  292. bodySign := fmt.Sprintf("%x", sha.Sum(nil))
  293. var path string
  294. var query string
  295. if len(rs.requestUri) > 0 {
  296. u, err := url.Parse(rs.requestUri)
  297. if err != nil {
  298. return nil, err
  299. }
  300. path = u.Path
  301. query = u.RawQuery
  302. } else {
  303. path = r.URL.Path
  304. query = r.URL.RawQuery
  305. }
  306. contentOfSign := strings.Join([]string{
  307. strconv.FormatInt(rs.timestamp, 10),
  308. rs.method,
  309. path,
  310. query,
  311. bodySign,
  312. }, "\n")
  313. rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
  314. }
  315. var mode string
  316. if rs.crypt {
  317. mode = "1"
  318. } else {
  319. mode = "0"
  320. }
  321. content := strings.Join([]string{
  322. "version=v1",
  323. "type=" + mode,
  324. fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
  325. "time=" + strconv.FormatInt(rs.timestamp, 10),
  326. }, "; ")
  327. encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
  328. if err != nil {
  329. log.Fatal(err)
  330. }
  331. output, err := encrypter.Encrypt([]byte(content))
  332. if err != nil {
  333. log.Fatal(err)
  334. }
  335. encryptedContent := base64.StdEncoding.EncodeToString(output)
  336. if !rs.missHeader {
  337. r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
  338. fmt.Sprintf("key=%s", rs.fingerprint),
  339. "secret=" + encryptedContent,
  340. "signature=" + rs.signature,
  341. }, "; "))
  342. }
  343. if len(rs.requestUri) > 0 {
  344. r.Header.Set("X-Request-Uri", rs.requestUri)
  345. }
  346. return r, nil
  347. }
  348. func createTempFile(body []byte) (string, error) {
  349. tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
  350. if err != nil {
  351. return "", err
  352. }
  353. tmpFile.Close()
  354. err = os.WriteFile(tmpFile.Name(), body, os.ModePerm)
  355. if err != nil {
  356. return "", err
  357. }
  358. return tmpFile.Name(), nil
  359. }