contentsecurityhandler_test.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. package handler
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "log"
  9. "net/http"
  10. "net/http/httptest"
  11. "net/url"
  12. "os"
  13. "strconv"
  14. "strings"
  15. "testing"
  16. "time"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/wuntsong-org/go-zero-plus/core/codec"
  19. "github.com/wuntsong-org/go-zero-plus/rest/httpx"
  20. )
  21. const timeDiff = time.Hour * 2 * 24
// Test fixtures shared by every case in this file.
var (
	// fingerprint identifies the RSA key pair: it is the map key under which
	// the decrypter is registered and the "key=" value sent in the header.
	fingerprint = "12345"
	// pubKey is the PEM public key used by buildRequest to RSA-encrypt the
	// "secret" part of the content-security header.
	pubKey = []byte(`-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
my47YlhspwszKdRP+wIDAQAB
-----END PUBLIC KEY-----`)
	// priKey is the matching PEM private key; tests write it to a temp file
	// and load it with codec.NewRsaDecrypter.
	priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
-----END RSA PRIVATE KEY-----`)
	// key is the 32-byte shared secret used to HMAC-sign requests
	// (codec.HmacBase64) and to encrypt bodies (codec.EcbEncrypt); it is
	// also sent base64-encoded inside the RSA-encrypted secret.
	key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
)
  47. type requestSettings struct {
  48. method string
  49. url string
  50. body io.Reader
  51. strict bool
  52. crypt bool
  53. requestUri string
  54. timestamp int64
  55. fingerprint string
  56. missHeader bool
  57. signature string
  58. }
  59. func TestContentSecurityHandler(t *testing.T) {
  60. tests := []struct {
  61. method string
  62. url string
  63. body string
  64. strict bool
  65. crypt bool
  66. requestUri string
  67. timestamp int64
  68. fingerprint string
  69. missHeader bool
  70. signature string
  71. statusCode int
  72. }{
  73. {
  74. method: http.MethodGet,
  75. url: "http://localhost/a/b?c=d&e=f",
  76. strict: true,
  77. crypt: false,
  78. },
  79. {
  80. method: http.MethodPost,
  81. url: "http://localhost/a/b?c=d&e=f",
  82. body: "hello",
  83. strict: true,
  84. crypt: false,
  85. },
  86. {
  87. method: http.MethodGet,
  88. url: "http://localhost/a/b?c=d&e=f",
  89. strict: true,
  90. crypt: true,
  91. },
  92. {
  93. method: http.MethodPost,
  94. url: "http://localhost/a/b?c=d&e=f",
  95. body: "hello",
  96. strict: true,
  97. crypt: true,
  98. },
  99. {
  100. method: http.MethodGet,
  101. url: "http://localhost/a/b?c=d&e=f",
  102. strict: true,
  103. crypt: true,
  104. timestamp: time.Now().Add(timeDiff).Unix(),
  105. statusCode: http.StatusForbidden,
  106. },
  107. {
  108. method: http.MethodPost,
  109. url: "http://localhost/a/b?c=d&e=f",
  110. body: "hello",
  111. strict: true,
  112. crypt: true,
  113. timestamp: time.Now().Add(-timeDiff).Unix(),
  114. statusCode: http.StatusForbidden,
  115. },
  116. {
  117. method: http.MethodPost,
  118. url: "http://remotehost/",
  119. body: "hello",
  120. strict: true,
  121. crypt: true,
  122. requestUri: "http://localhost/a/b?c=d&e=f",
  123. },
  124. {
  125. method: http.MethodPost,
  126. url: "http://localhost/a/b?c=d&e=f",
  127. body: "hello",
  128. strict: false,
  129. crypt: true,
  130. fingerprint: "badone",
  131. },
  132. {
  133. method: http.MethodPost,
  134. url: "http://localhost/a/b?c=d&e=f",
  135. body: "hello",
  136. strict: true,
  137. crypt: true,
  138. timestamp: time.Now().Add(-timeDiff).Unix(),
  139. fingerprint: "badone",
  140. statusCode: http.StatusForbidden,
  141. },
  142. {
  143. method: http.MethodPost,
  144. url: "http://localhost/a/b?c=d&e=f",
  145. body: "hello",
  146. strict: true,
  147. crypt: true,
  148. missHeader: true,
  149. statusCode: http.StatusForbidden,
  150. },
  151. {
  152. method: http.MethodHead,
  153. url: "http://localhost/a/b?c=d&e=f",
  154. strict: true,
  155. crypt: false,
  156. },
  157. {
  158. method: http.MethodGet,
  159. url: "http://localhost/a/b?c=d&e=f",
  160. strict: true,
  161. crypt: false,
  162. signature: "badone",
  163. statusCode: http.StatusForbidden,
  164. },
  165. }
  166. for _, test := range tests {
  167. t.Run(test.url, func(t *testing.T) {
  168. if test.statusCode == 0 {
  169. test.statusCode = http.StatusOK
  170. }
  171. if len(test.fingerprint) == 0 {
  172. test.fingerprint = fingerprint
  173. }
  174. if test.timestamp == 0 {
  175. test.timestamp = time.Now().Unix()
  176. }
  177. func() {
  178. keyFile, err := createTempFile(priKey)
  179. defer os.Remove(keyFile)
  180. assert.Nil(t, err)
  181. decrypter, err := codec.NewRsaDecrypter(keyFile)
  182. assert.Nil(t, err)
  183. contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
  184. fingerprint: decrypter,
  185. }, time.Hour, test.strict)
  186. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  187. }))
  188. var reader io.Reader
  189. if len(test.body) > 0 {
  190. reader = strings.NewReader(test.body)
  191. }
  192. setting := requestSettings{
  193. method: test.method,
  194. url: test.url,
  195. body: reader,
  196. strict: test.strict,
  197. crypt: test.crypt,
  198. requestUri: test.requestUri,
  199. timestamp: test.timestamp,
  200. fingerprint: test.fingerprint,
  201. missHeader: test.missHeader,
  202. signature: test.signature,
  203. }
  204. req, err := buildRequest(setting)
  205. assert.Nil(t, err)
  206. resp := httptest.NewRecorder()
  207. handler.ServeHTTP(resp, req)
  208. assert.Equal(t, test.statusCode, resp.Code)
  209. }()
  210. })
  211. }
  212. }
  213. func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
  214. keyFile, err := createTempFile(priKey)
  215. defer os.Remove(keyFile)
  216. assert.Nil(t, err)
  217. decrypter, err := codec.NewRsaDecrypter(keyFile)
  218. assert.Nil(t, err)
  219. contentSecurityHandler := ContentSecurityHandler(
  220. map[string]codec.RsaDecrypter{
  221. fingerprint: decrypter,
  222. },
  223. time.Hour,
  224. true,
  225. func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  226. w.WriteHeader(http.StatusOK)
  227. })
  228. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  229. setting := requestSettings{
  230. method: http.MethodGet,
  231. url: "http://localhost/a/b?c=d&e=f",
  232. signature: "badone",
  233. }
  234. req, err := buildRequest(setting)
  235. assert.Nil(t, err)
  236. resp := httptest.NewRecorder()
  237. handler.ServeHTTP(resp, req)
  238. assert.Equal(t, http.StatusOK, resp.Code)
  239. }
  240. func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
  241. keyFile, err := createTempFile(priKey)
  242. defer os.Remove(keyFile)
  243. assert.Nil(t, err)
  244. decrypter, err := codec.NewRsaDecrypter(keyFile)
  245. assert.Nil(t, err)
  246. contentSecurityHandler := ContentSecurityHandler(
  247. map[string]codec.RsaDecrypter{
  248. fingerprint: decrypter,
  249. },
  250. time.Hour,
  251. true,
  252. func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  253. assert.Equal(t, httpx.CodeSignatureWrongTime, code)
  254. w.WriteHeader(http.StatusOK)
  255. })
  256. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  257. reader := strings.NewReader("hello")
  258. setting := requestSettings{
  259. method: http.MethodPost,
  260. url: "http://localhost/a/b?c=d&e=f",
  261. body: reader,
  262. strict: true,
  263. crypt: true,
  264. timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
  265. fingerprint: fingerprint,
  266. }
  267. req, err := buildRequest(setting)
  268. assert.Nil(t, err)
  269. resp := httptest.NewRecorder()
  270. handler.ServeHTTP(resp, req)
  271. assert.Equal(t, http.StatusOK, resp.Code)
  272. }
  273. func buildRequest(rs requestSettings) (*http.Request, error) {
  274. var bodyStr string
  275. var err error
  276. if rs.crypt && rs.body != nil {
  277. var buf bytes.Buffer
  278. io.Copy(&buf, rs.body)
  279. bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
  280. if err != nil {
  281. return nil, err
  282. }
  283. bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
  284. }
  285. r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
  286. if len(rs.signature) == 0 {
  287. sha := sha256.New()
  288. sha.Write([]byte(bodyStr))
  289. bodySign := fmt.Sprintf("%x", sha.Sum(nil))
  290. var path string
  291. var query string
  292. if len(rs.requestUri) > 0 {
  293. u, err := url.Parse(rs.requestUri)
  294. if err != nil {
  295. return nil, err
  296. }
  297. path = u.Path
  298. query = u.RawQuery
  299. } else {
  300. path = r.URL.Path
  301. query = r.URL.RawQuery
  302. }
  303. contentOfSign := strings.Join([]string{
  304. strconv.FormatInt(rs.timestamp, 10),
  305. rs.method,
  306. path,
  307. query,
  308. bodySign,
  309. }, "\n")
  310. rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
  311. }
  312. var mode string
  313. if rs.crypt {
  314. mode = "1"
  315. } else {
  316. mode = "0"
  317. }
  318. content := strings.Join([]string{
  319. "version=v1",
  320. "type=" + mode,
  321. fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
  322. "time=" + strconv.FormatInt(rs.timestamp, 10),
  323. }, "; ")
  324. encrypter, err := codec.NewRsaEncrypter(pubKey)
  325. if err != nil {
  326. log.Fatal(err)
  327. }
  328. output, err := encrypter.Encrypt([]byte(content))
  329. if err != nil {
  330. log.Fatal(err)
  331. }
  332. encryptedContent := base64.StdEncoding.EncodeToString(output)
  333. if !rs.missHeader {
  334. r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
  335. fmt.Sprintf("key=%s", rs.fingerprint),
  336. "secret=" + encryptedContent,
  337. "signature=" + rs.signature,
  338. }, "; "))
  339. }
  340. if len(rs.requestUri) > 0 {
  341. r.Header.Set("X-Request-Uri", rs.requestUri)
  342. }
  343. return r, nil
  344. }
  345. func createTempFile(body []byte) (string, error) {
  346. tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
  347. if err != nil {
  348. return "", err
  349. }
  350. tmpFile.Close()
  351. if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
  352. return "", err
  353. }
  354. return tmpFile.Name(), nil
  355. }