engine_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package rest
  2. import (
  3. "context"
  4. "errors"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync/atomic"
  8. "testing"
  9. "time"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/zeromicro/go-zero/core/conf"
  12. "github.com/zeromicro/go-zero/core/logx"
  13. )
  14. func TestNewEngine(t *testing.T) {
  15. yamls := []string{
  16. `Name: foo
  17. Port: 54321
  18. `,
  19. `Name: foo
  20. Port: 54321
  21. CpuThreshold: 500
  22. `,
  23. `Name: foo
  24. Port: 54321
  25. CpuThreshold: 500
  26. Verbose: true
  27. `,
  28. }
  29. routes := []featuredRoutes{
  30. {
  31. jwt: jwtSetting{},
  32. signature: signatureSetting{},
  33. routes: []Route{{
  34. Method: http.MethodGet,
  35. Path: "/",
  36. Handler: func(w http.ResponseWriter, r *http.Request) {},
  37. }},
  38. },
  39. {
  40. priority: true,
  41. jwt: jwtSetting{},
  42. signature: signatureSetting{},
  43. routes: []Route{{
  44. Method: http.MethodGet,
  45. Path: "/",
  46. Handler: func(w http.ResponseWriter, r *http.Request) {},
  47. }},
  48. },
  49. {
  50. priority: true,
  51. jwt: jwtSetting{
  52. enabled: true,
  53. },
  54. signature: signatureSetting{},
  55. routes: []Route{{
  56. Method: http.MethodGet,
  57. Path: "/",
  58. Handler: func(w http.ResponseWriter, r *http.Request) {},
  59. }},
  60. },
  61. {
  62. priority: true,
  63. jwt: jwtSetting{
  64. enabled: true,
  65. prevSecret: "thesecret",
  66. },
  67. signature: signatureSetting{},
  68. routes: []Route{{
  69. Method: http.MethodGet,
  70. Path: "/",
  71. Handler: func(w http.ResponseWriter, r *http.Request) {},
  72. }},
  73. },
  74. {
  75. priority: true,
  76. jwt: jwtSetting{
  77. enabled: true,
  78. },
  79. signature: signatureSetting{},
  80. routes: []Route{{
  81. Method: http.MethodGet,
  82. Path: "/",
  83. Handler: func(w http.ResponseWriter, r *http.Request) {},
  84. }},
  85. },
  86. {
  87. priority: true,
  88. jwt: jwtSetting{
  89. enabled: true,
  90. },
  91. signature: signatureSetting{
  92. enabled: true,
  93. },
  94. routes: []Route{{
  95. Method: http.MethodGet,
  96. Path: "/",
  97. Handler: func(w http.ResponseWriter, r *http.Request) {},
  98. }},
  99. },
  100. {
  101. priority: true,
  102. jwt: jwtSetting{
  103. enabled: true,
  104. },
  105. signature: signatureSetting{
  106. enabled: true,
  107. SignatureConf: SignatureConf{
  108. Strict: true,
  109. },
  110. },
  111. routes: []Route{{
  112. Method: http.MethodGet,
  113. Path: "/",
  114. Handler: func(w http.ResponseWriter, r *http.Request) {},
  115. }},
  116. },
  117. {
  118. priority: true,
  119. jwt: jwtSetting{
  120. enabled: true,
  121. },
  122. signature: signatureSetting{
  123. enabled: true,
  124. SignatureConf: SignatureConf{
  125. Strict: true,
  126. PrivateKeys: []PrivateKeyConf{
  127. {
  128. Fingerprint: "a",
  129. KeyFile: "b",
  130. },
  131. },
  132. },
  133. },
  134. routes: []Route{{
  135. Method: http.MethodGet,
  136. Path: "/",
  137. Handler: func(w http.ResponseWriter, r *http.Request) {},
  138. }},
  139. },
  140. }
  141. for _, yaml := range yamls {
  142. for _, route := range routes {
  143. var cnf RestConf
  144. assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
  145. ng := newEngine(cnf)
  146. ng.addRoutes(route)
  147. ng.use(func(next http.HandlerFunc) http.HandlerFunc {
  148. return func(w http.ResponseWriter, r *http.Request) {
  149. next.ServeHTTP(w, r)
  150. }
  151. })
  152. assert.NotNil(t, ng.start(mockedRouter{}))
  153. }
  154. }
  155. }
  156. func TestEngine_checkedTimeout(t *testing.T) {
  157. tests := []struct {
  158. name string
  159. timeout time.Duration
  160. expect time.Duration
  161. }{
  162. {
  163. name: "not set",
  164. expect: time.Second,
  165. },
  166. {
  167. name: "less",
  168. timeout: time.Millisecond * 500,
  169. expect: time.Millisecond * 500,
  170. },
  171. {
  172. name: "equal",
  173. timeout: time.Second,
  174. expect: time.Second,
  175. },
  176. {
  177. name: "more",
  178. timeout: time.Millisecond * 1500,
  179. expect: time.Millisecond * 1500,
  180. },
  181. }
  182. ng := newEngine(RestConf{
  183. Timeout: 1000,
  184. })
  185. for _, test := range tests {
  186. assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
  187. }
  188. }
  189. func TestEngine_checkedMaxBytes(t *testing.T) {
  190. tests := []struct {
  191. name string
  192. maxBytes int64
  193. expect int64
  194. }{
  195. {
  196. name: "not set",
  197. expect: 1000,
  198. },
  199. {
  200. name: "less",
  201. maxBytes: 500,
  202. expect: 500,
  203. },
  204. {
  205. name: "equal",
  206. maxBytes: 1000,
  207. expect: 1000,
  208. },
  209. {
  210. name: "more",
  211. maxBytes: 1500,
  212. expect: 1500,
  213. },
  214. }
  215. ng := newEngine(RestConf{
  216. MaxBytes: 1000,
  217. })
  218. for _, test := range tests {
  219. assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
  220. }
  221. }
  222. func TestEngine_notFoundHandler(t *testing.T) {
  223. logx.Disable()
  224. ng := newEngine(RestConf{})
  225. ts := httptest.NewServer(ng.notFoundHandler(nil))
  226. defer ts.Close()
  227. client := ts.Client()
  228. err := func(ctx context.Context) error {
  229. req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
  230. assert.Nil(t, err)
  231. res, err := client.Do(req)
  232. assert.Nil(t, err)
  233. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  234. return res.Body.Close()
  235. }(context.Background())
  236. assert.Nil(t, err)
  237. }
  238. func TestEngine_notFoundHandlerNotNil(t *testing.T) {
  239. logx.Disable()
  240. ng := newEngine(RestConf{})
  241. var called int32
  242. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  243. atomic.AddInt32(&called, 1)
  244. })))
  245. defer ts.Close()
  246. client := ts.Client()
  247. err := func(ctx context.Context) error {
  248. req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
  249. assert.Nil(t, err)
  250. res, err := client.Do(req)
  251. assert.Nil(t, err)
  252. assert.Equal(t, http.StatusNotFound, res.StatusCode)
  253. return res.Body.Close()
  254. }(context.Background())
  255. assert.Nil(t, err)
  256. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  257. }
  258. func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
  259. logx.Disable()
  260. ng := newEngine(RestConf{})
  261. var called int32
  262. ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  263. atomic.AddInt32(&called, 1)
  264. w.WriteHeader(http.StatusExpectationFailed)
  265. })))
  266. defer ts.Close()
  267. client := ts.Client()
  268. err := func(ctx context.Context) error {
  269. req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
  270. assert.Nil(t, err)
  271. res, err := client.Do(req)
  272. assert.Nil(t, err)
  273. assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
  274. return res.Body.Close()
  275. }(context.Background())
  276. assert.Nil(t, err)
  277. assert.Equal(t, int32(1), atomic.LoadInt32(&called))
  278. }
  279. type mockedRouter struct{}
  280. func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  281. }
  282. func (m mockedRouter) Handle(method, path string, handler http.Handler) error {
  283. return errors.New("foo")
  284. }
  285. func (m mockedRouter) SetNotFoundHandler(handler http.Handler) {
  286. }
  287. func (m mockedRouter) SetNotAllowedHandler(handler http.Handler) {
  288. }