authinterceptor_test.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "testing"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/zeromicro/go-zero/core/stores/redis/redistest"
  7. "github.com/zeromicro/go-zero/zrpc/internal/auth"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/metadata"
  10. )
  11. func TestStreamAuthorizeInterceptor(t *testing.T) {
  12. tests := []struct {
  13. name string
  14. app string
  15. token string
  16. strict bool
  17. hasError bool
  18. }{
  19. {
  20. name: "strict=false",
  21. strict: false,
  22. hasError: false,
  23. },
  24. {
  25. name: "strict=true",
  26. strict: true,
  27. hasError: true,
  28. },
  29. {
  30. name: "strict=true,with token",
  31. app: "foo",
  32. token: "bar",
  33. strict: true,
  34. hasError: false,
  35. },
  36. {
  37. name: "strict=true,with error token",
  38. app: "foo",
  39. token: "error",
  40. strict: true,
  41. hasError: true,
  42. },
  43. }
  44. store := redistest.CreateRedis(t)
  45. for _, test := range tests {
  46. t.Run(test.name, func(t *testing.T) {
  47. if len(test.app) > 0 {
  48. assert.Nil(t, store.Hset("apps", test.app, test.token))
  49. defer store.Hdel("apps", test.app)
  50. }
  51. authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
  52. assert.Nil(t, err)
  53. interceptor := StreamAuthorizeInterceptor(authenticator)
  54. md := metadata.New(map[string]string{
  55. "app": "foo",
  56. "token": "bar",
  57. })
  58. ctx := metadata.NewIncomingContext(context.Background(), md)
  59. stream := mockedStream{ctx: ctx}
  60. err = interceptor(nil, stream, nil, func(_ any, _ grpc.ServerStream) error {
  61. return nil
  62. })
  63. if test.hasError {
  64. assert.NotNil(t, err)
  65. } else {
  66. assert.Nil(t, err)
  67. }
  68. })
  69. }
  70. }
  71. func TestUnaryAuthorizeInterceptor(t *testing.T) {
  72. tests := []struct {
  73. name string
  74. app string
  75. token string
  76. strict bool
  77. hasError bool
  78. }{
  79. {
  80. name: "strict=false",
  81. strict: false,
  82. hasError: false,
  83. },
  84. {
  85. name: "strict=true",
  86. strict: true,
  87. hasError: true,
  88. },
  89. {
  90. name: "strict=true,with token",
  91. app: "foo",
  92. token: "bar",
  93. strict: true,
  94. hasError: false,
  95. },
  96. {
  97. name: "strict=true,with error token",
  98. app: "foo",
  99. token: "error",
  100. strict: true,
  101. hasError: true,
  102. },
  103. }
  104. store := redistest.CreateRedis(t)
  105. for _, test := range tests {
  106. t.Run(test.name, func(t *testing.T) {
  107. if len(test.app) > 0 {
  108. assert.Nil(t, store.Hset("apps", test.app, test.token))
  109. defer store.Hdel("apps", test.app)
  110. }
  111. authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
  112. assert.Nil(t, err)
  113. interceptor := UnaryAuthorizeInterceptor(authenticator)
  114. md := metadata.New(map[string]string{
  115. "app": "foo",
  116. "token": "bar",
  117. })
  118. ctx := metadata.NewIncomingContext(context.Background(), md)
  119. _, err = interceptor(ctx, nil, nil,
  120. func(ctx context.Context, req any) (any, error) {
  121. return nil, nil
  122. })
  123. if test.hasError {
  124. assert.NotNil(t, err)
  125. } else {
  126. assert.Nil(t, err)
  127. }
  128. if test.strict {
  129. _, err = interceptor(context.Background(), nil, nil,
  130. func(ctx context.Context, req any) (any, error) {
  131. return nil, nil
  132. })
  133. assert.NotNil(t, err)
  134. var md metadata.MD
  135. ctx := metadata.NewIncomingContext(context.Background(), md)
  136. _, err = interceptor(ctx, nil, nil,
  137. func(ctx context.Context, req any) (any, error) {
  138. return nil, nil
  139. })
  140. assert.NotNil(t, err)
  141. md = metadata.New(map[string]string{
  142. "app": "",
  143. "token": "",
  144. })
  145. ctx = metadata.NewIncomingContext(context.Background(), md)
  146. _, err = interceptor(ctx, nil, nil,
  147. func(ctx context.Context, req any) (any, error) {
  148. return nil, nil
  149. })
  150. assert.NotNil(t, err)
  151. }
  152. })
  153. }
  154. }
  155. type mockedStream struct {
  156. ctx context.Context
  157. }
  158. func (m mockedStream) SetHeader(md metadata.MD) error {
  159. return nil
  160. }
  161. func (m mockedStream) SendHeader(md metadata.MD) error {
  162. return nil
  163. }
  164. func (m mockedStream) SetTrailer(md metadata.MD) {
  165. }
  166. func (m mockedStream) Context() context.Context {
  167. return m.ctx
  168. }
  169. func (m mockedStream) SendMsg(v any) error {
  170. return nil
  171. }
  172. func (m mockedStream) RecvMsg(v any) error {
  173. return nil
  174. }