kevin 4 роки тому
батько
коміт
dc286a03f5

+ 62 - 0
rpcx/internal/auth/credential_test.go

@@ -0,0 +1,62 @@
+package auth
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/grpc/metadata"
+)
+
+func TestParseCredential(t *testing.T) {
+	tests := []struct {
+		name        string
+		withNil     bool
+		withEmptyMd bool
+		app         string
+		token       string
+	}{
+		{
+			name:    "nil",
+			withNil: true,
+		},
+		{
+			name:        "empty md",
+			withEmptyMd: true,
+		},
+		{
+			name: "empty",
+		},
+		{
+			name:  "valid",
+			app:   "foo",
+			token: "bar",
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			var ctx context.Context
+			if test.withNil {
+				ctx = context.Background()
+			} else if test.withEmptyMd {
+				ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{})
+			} else {
+				md := metadata.New(map[string]string{
+					"app":   test.app,
+					"token": test.token,
+				})
+				ctx = metadata.NewIncomingContext(context.Background(), md)
+			}
+			cred := ParseCredential(ctx)
+			assert.False(t, cred.RequireTransportSecurity())
+			m, err := cred.GetRequestMetadata(context.Background())
+			assert.Nil(t, err)
+			assert.Equal(t, test.app, m[appKey])
+			assert.Equal(t, test.token, m[tokenKey])
+		})
+	}
+}

+ 154 - 9
rpcx/internal/serverinterceptors/authinterceptor_test.go

@@ -8,21 +8,108 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/core/stores/redis"
 	"github.com/tal-tech/go-zero/rpcx/internal/auth"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
 )
 
+func TestStreamAuthorizeInterceptor(t *testing.T) {
+	tests := []struct {
+		name     string
+		app      string
+		token    string
+		strict   bool
+		hasError bool
+	}{
+		{
+			name:     "strict=false",
+			strict:   false,
+			hasError: false,
+		},
+		{
+			name:     "strict=true",
+			strict:   true,
+			hasError: true,
+		},
+		{
+			name:     "strict=true,with token",
+			app:      "foo",
+			token:    "bar",
+			strict:   true,
+			hasError: false,
+		},
+		{
+			name:     "strict=true,with error token",
+			app:      "foo",
+			token:    "error",
+			strict:   true,
+			hasError: true,
+		},
+	}
+
+	r := miniredis.NewMiniRedis()
+	assert.Nil(t, r.Start())
+	defer r.Close()
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			store := redis.NewRedis(r.Addr(), redis.NodeType)
+			if len(test.app) > 0 {
+				assert.Nil(t, store.Hset("apps", test.app, test.token))
+				defer store.Hdel("apps", test.app)
+			}
+
+			authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
+			assert.Nil(t, err)
+			interceptor := StreamAuthorizeInterceptor(authenticator)
+			md := metadata.New(map[string]string{
+				"app":   "foo",
+				"token": "bar",
+			})
+			ctx := metadata.NewIncomingContext(context.Background(), md)
+			stream := mockedStream{ctx: ctx}
+			err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
+				return nil
+			})
+			if test.hasError {
+				assert.NotNil(t, err)
+			} else {
+				assert.Nil(t, err)
+			}
+		})
+	}
+}
+
 func TestUnaryAuthorizeInterceptor(t *testing.T) {
 	tests := []struct {
-		name   string
-		strict bool
+		name     string
+		app      string
+		token    string
+		strict   bool
+		hasError bool
 	}{
 		{
-			name:   "strict=true",
-			strict: true,
+			name:     "strict=false",
+			strict:   false,
+			hasError: false,
+		},
+		{
+			name:     "strict=true",
+			strict:   true,
+			hasError: true,
+		},
+		{
+			name:     "strict=true,with token",
+			app:      "foo",
+			token:    "bar",
+			strict:   true,
+			hasError: false,
 		},
 		{
-			name:   "strict=false",
-			strict: false,
+			name:     "strict=true,with error token",
+			app:      "foo",
+			token:    "error",
+			strict:   true,
+			hasError: true,
 		},
 	}
 
@@ -33,23 +120,81 @@ func TestUnaryAuthorizeInterceptor(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			store := redis.NewRedis(r.Addr(), redis.NodeType)
+			if len(test.app) > 0 {
+				assert.Nil(t, store.Hset("apps", test.app, test.token))
+				defer store.Hdel("apps", test.app)
+			}
+
 			authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
 			assert.Nil(t, err)
 			interceptor := UnaryAuthorizeInterceptor(authenticator)
 			md := metadata.New(map[string]string{
-				"app":   "name",
-				"token": "key",
+				"app":   "foo",
+				"token": "bar",
 			})
 			ctx := metadata.NewIncomingContext(context.Background(), md)
 			_, err = interceptor(ctx, nil, nil,
 				func(ctx context.Context, req interface{}) (interface{}, error) {
 					return nil, nil
 				})
-			if test.strict {
+			if test.hasError {
 				assert.NotNil(t, err)
 			} else {
 				assert.Nil(t, err)
 			}
+			if test.strict {
+				_, err = interceptor(context.Background(), nil, nil,
+					func(ctx context.Context, req interface{}) (interface{}, error) {
+						return nil, nil
+					})
+				assert.NotNil(t, err)
+
+				var md metadata.MD
+				ctx := metadata.NewIncomingContext(context.Background(), md)
+				_, err = interceptor(ctx, nil, nil,
+					func(ctx context.Context, req interface{}) (interface{}, error) {
+						return nil, nil
+					})
+				assert.NotNil(t, err)
+
+				md = metadata.New(map[string]string{
+					"app":   "",
+					"token": "",
+				})
+				ctx = metadata.NewIncomingContext(context.Background(), md)
+				_, err = interceptor(ctx, nil, nil,
+					func(ctx context.Context, req interface{}) (interface{}, error) {
+						return nil, nil
+					})
+				assert.NotNil(t, err)
+			}
 		})
 	}
 }
+
+type mockedStream struct {
+	ctx context.Context
+}
+
+func (m mockedStream) SetHeader(md metadata.MD) error {
+	return nil
+}
+
+func (m mockedStream) SendHeader(md metadata.MD) error {
+	return nil
+}
+
+func (m mockedStream) SetTrailer(md metadata.MD) {
+}
+
+func (m mockedStream) Context() context.Context {
+	return m.ctx
+}
+
+func (m mockedStream) SendMsg(v interface{}) error {
+	return nil
+}
+
+func (m mockedStream) RecvMsg(v interface{}) error {
+	return nil
+}