123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- package rest
- import (
- "context"
- "crypto/tls"
- "errors"
- "fmt"
- "github.com/wuntsong-org/go-zero-plus/rest/httpx"
- "net/http"
- "net/http/httptest"
- "os"
- "sync/atomic"
- "testing"
- "time"
- "github.com/stretchr/testify/assert"
- "github.com/wuntsong-org/go-zero-plus/core/conf"
- "github.com/wuntsong-org/go-zero-plus/core/fs"
- "github.com/wuntsong-org/go-zero-plus/core/logx"
- "github.com/wuntsong-org/go-zero-plus/rest/router"
- )
// priKey is a fixture RSA private key in PEM form. TestNewEngine writes it to a
// temporary file and references that file as a signature PrivateKeys KeyFile.
const priKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQC4TJk3onpqb2RYE3wwt23J9SHLFstHGSkUYFLe+nl1dEKHbD+/
Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/kIE2ko4lbh/v8Fl14AyVR9ms
fKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32pv7q5UimZ205iKSBmgQIDAQAB
AoGAM5mWqGIAXj5z3MkP01/4CDxuyrrGDVD5FHBno3CDgyQa4Gmpa4B0/ywj671B
aTnwKmSmiiCN2qleuQYASixes2zY5fgTzt+7KNkl9JHsy7i606eH2eCKzsUa/s6u
WD8V3w/hGCQ9zYI18ihwyXlGHIgcRz/eeRh+nWcWVJzGOPUCQQD5nr6It/1yHb1p
C6l4fC4xXF19l4KxJjGu1xv/sOpSx0pOqBDEX3Mh//FU954392rUWDXV1/I65BPt
TLphdsu3AkEAvQJ2Qay/lffFj9FaUrvXuftJZ/Ypn0FpaSiUh3Ak3obBT6UvSZS0
bcYdCJCNHDtBOsWHnIN1x+BcWAPrdU7PhwJBAIQ0dUlH2S3VXnoCOTGc44I1Hzbj
Rc65IdsuBqA3fQN2lX5vOOIog3vgaFrOArg1jBkG1wx5IMvb/EnUN2pjVqUCQCza
KLXtCInOAlPemlCHwumfeAvznmzsWNdbieOZ+SXVVIpR6KbNYwOpv7oIk3Pfm9sW
hNffWlPUKhW42Gc+DIECQQDmk20YgBXwXWRM5DRPbhisIV088N5Z58K9DtFWkZsd
OBDT3dFcgZONtlmR1MqZO0pTh30lA4qovYj3Bx7A8i36
-----END RSA PRIVATE KEY-----`
- func TestNewEngine(t *testing.T) {
- priKeyfile, err := fs.TempFilenameWithText(priKey)
- assert.Nil(t, err)
- defer os.Remove(priKeyfile)
- yamls := []string{
- `Name: foo
- Host: localhost
- Port: 0
- Middlewares:
- Log: false
- `,
- `Name: foo
- Host: localhost
- Port: 0
- CpuThreshold: 500
- Middlewares:
- Log: false
- `,
- `Name: foo
- Host: localhost
- Port: 0
- CpuThreshold: 500
- Verbose: true
- `,
- }
- routes := []featuredRoutes{
- {
- jwt: jwtSetting{},
- signature: signatureSetting{},
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- timeout: time.Minute,
- },
- {
- priority: true,
- jwt: jwtSetting{},
- signature: signatureSetting{},
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- timeout: time.Second,
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- },
- signature: signatureSetting{},
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- prevSecret: "thesecret",
- },
- signature: signatureSetting{},
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- },
- signature: signatureSetting{},
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- },
- signature: signatureSetting{
- enabled: true,
- },
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- },
- signature: signatureSetting{
- enabled: true,
- SignatureConf: SignatureConf{
- Strict: true,
- },
- },
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- },
- signature: signatureSetting{
- enabled: true,
- SignatureConf: SignatureConf{
- Strict: true,
- PrivateKeys: []PrivateKeyConf{
- {
- Fingerprint: "a",
- KeyFile: "b",
- },
- },
- },
- },
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- {
- priority: true,
- jwt: jwtSetting{
- enabled: true,
- },
- signature: signatureSetting{
- enabled: true,
- SignatureConf: SignatureConf{
- Strict: true,
- PrivateKeys: []PrivateKeyConf{
- {
- Fingerprint: "a",
- KeyFile: priKeyfile,
- },
- },
- },
- },
- routes: []Route{{
- Method: http.MethodGet,
- Path: "/",
- Handler: func(w http.ResponseWriter, r *http.Request) {},
- }},
- },
- }
- var index int32
- for _, yaml := range yamls {
- yaml := yaml
- for _, route := range routes {
- route := route
- t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
- var cnf RestConf
- assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
- ng := newEngine(cnf)
- if atomic.AddInt32(&index, 1)%2 == 0 {
- ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
- next http.Handler, strict bool, code int) {
- })
- }
- ng.addRoutes(route)
- ng.use(func(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- next.ServeHTTP(w, r)
- }
- })
- assert.NotNil(t, ng.start(nil, mockedRouter{}, func(svr *http.Server) {
- }))
- timeout := time.Second * 3
- if route.timeout > timeout {
- timeout = route.timeout
- }
- assert.Equal(t, timeout, ng.timeout)
- })
- }
- }
- }
- func TestEngine_checkedTimeout(t *testing.T) {
- tests := []struct {
- name string
- timeout time.Duration
- expect time.Duration
- }{
- {
- name: "not set",
- expect: time.Second,
- },
- {
- name: "less",
- timeout: time.Millisecond * 500,
- expect: time.Millisecond * 500,
- },
- {
- name: "equal",
- timeout: time.Second,
- expect: time.Second,
- },
- {
- name: "more",
- timeout: time.Millisecond * 1500,
- expect: time.Millisecond * 1500,
- },
- }
- ng := newEngine(RestConf{
- Timeout: 1000,
- })
- for _, test := range tests {
- assert.Equal(t, test.expect, ng.checkedTimeout(test.timeout))
- }
- }
- func TestEngine_checkedMaxBytes(t *testing.T) {
- tests := []struct {
- name string
- maxBytes int64
- expect int64
- }{
- {
- name: "not set",
- expect: 1000,
- },
- {
- name: "less",
- maxBytes: 500,
- expect: 500,
- },
- {
- name: "equal",
- maxBytes: 1000,
- expect: 1000,
- },
- {
- name: "more",
- maxBytes: 1500,
- expect: 1500,
- },
- }
- ng := newEngine(RestConf{
- MaxBytes: 1000,
- })
- for _, test := range tests {
- assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
- }
- }
- func TestEngine_notFoundHandler(t *testing.T) {
- logx.Disable()
- ng := newEngine(RestConf{})
- ts := httptest.NewServer(ng.notFoundHandler(nil))
- defer ts.Close()
- client := ts.Client()
- err := func(_ context.Context) error {
- req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
- assert.Nil(t, err)
- res, err := client.Do(req)
- assert.Nil(t, err)
- assert.Equal(t, http.StatusNotFound, res.StatusCode)
- return res.Body.Close()
- }(context.Background())
- assert.Nil(t, err)
- }
- func TestEngine_notFoundHandlerNotNil(t *testing.T) {
- logx.Disable()
- ng := newEngine(RestConf{})
- var called int32
- ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- atomic.AddInt32(&called, 1)
- })))
- defer ts.Close()
- client := ts.Client()
- err := func(_ context.Context) error {
- req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
- assert.Nil(t, err)
- res, err := client.Do(req)
- assert.Nil(t, err)
- assert.Equal(t, http.StatusNotFound, res.StatusCode)
- return res.Body.Close()
- }(context.Background())
- assert.Nil(t, err)
- assert.Equal(t, int32(1), atomic.LoadInt32(&called))
- }
- func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
- logx.Disable()
- ng := newEngine(RestConf{})
- var called int32
- ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- atomic.AddInt32(&called, 1)
- w.WriteHeader(http.StatusExpectationFailed)
- })))
- defer ts.Close()
- client := ts.Client()
- err := func(_ context.Context) error {
- req, err := http.NewRequest("GET", ts.URL+"/bad", http.NoBody)
- assert.Nil(t, err)
- res, err := client.Do(req)
- assert.Nil(t, err)
- assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
- return res.Body.Close()
- }(context.Background())
- assert.Nil(t, err)
- assert.Equal(t, int32(1), atomic.LoadInt32(&called))
- }
- func TestEngine_withTimeout(t *testing.T) {
- logx.Disable()
- tests := []struct {
- name string
- timeout int64
- }{
- {
- name: "not set",
- },
- {
- name: "set",
- timeout: 1000,
- },
- }
- for _, test := range tests {
- test := test
- t.Run(test.name, func(t *testing.T) {
- ng := newEngine(RestConf{Timeout: test.timeout})
- svr := &http.Server{}
- ng.withTimeout()(svr)
- assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
- assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
- assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
- assert.Equal(t, time.Duration(0), svr.IdleTimeout)
- })
- }
- }
- func TestEngine_start(t *testing.T) {
- logx.Disable()
- t.Run("http", func(t *testing.T) {
- ng := newEngine(RestConf{
- Host: "localhost",
- Port: -1,
- })
- assert.Error(t, ng.start(nil, router.NewRouter()))
- })
- t.Run("https", func(t *testing.T) {
- ng := newEngine(RestConf{
- Host: "localhost",
- Port: -1,
- CertFile: "foo",
- KeyFile: "bar",
- })
- ng.tlsConfig = &tls.Config{}
- assert.Error(t, ng.start(nil, router.NewRouter()))
- })
- }
- type mockedRouter struct {
- }
- func (m mockedRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
- }
- func (m mockedRouter) Handle(_, _ string, handler http.Handler) error {
- return errors.New("foo")
- }
- func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
- }
- func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
- }
- func (m mockedRouter) SetOptionsHandler(_ http.Handler) {
- }
- func (m mockedRouter) SetMiddleware(_ httpx.MiddlewareFunc) {
- }
|