Selaa lähdekoodia

feat: refactor gateway code (#3160)

MarkJoyMa 2 vuotta sitten
vanhempi
sitoutus
9970ff55cd
2 muutettua tiedostoa jossa 58 lisäystä ja 64 poistoa
  1. 22 40
      gateway/server.go
  2. 36 24
      gateway/server_test.go

+ 22 - 40
gateway/server.go

@@ -21,26 +21,21 @@ import (
 type (
 	// Server is a gateway server.
 	Server struct {
-		c GatewayConf
 		*rest.Server
-		upstreams     []*upstream
+		upstreams     []Upstream
 		processHeader func(http.Header) []string
+		dialer        func(conf zrpc.RpcClientConf) zrpc.Client
 	}
 
 	// Option defines the method to customize Server.
 	Option func(svr *Server)
-
-	upstream struct {
-		Upstream
-		client zrpc.Client
-	}
 )
 
 // MustNewServer creates a new gateway server.
 func MustNewServer(c GatewayConf, opts ...Option) *Server {
 	svr := &Server{
-		c:      c,
-		Server: rest.MustNewServer(c.RestConf),
+		upstreams: c.Upstreams,
+		Server:    rest.MustNewServer(c.RestConf),
 	}
 	for _, opt := range opts {
 		opt(svr)
@@ -61,23 +56,15 @@ func (s *Server) Stop() {
 }
 
 func (s *Server) build() error {
-	if err := s.buildClient(); err != nil {
-		return err
-	}
-
-	return s.buildUpstream()
-}
-
-func (s *Server) buildClient() error {
 	if err := s.ensureUpstreamNames(); err != nil {
 		return err
 	}
 
 	return mr.MapReduceVoid(func(source chan<- Upstream) {
-		for _, up := range s.c.Upstreams {
+		for _, up := range s.upstreams {
 			source <- up
 		}
-	}, func(up Upstream, writer mr.Writer[*upstream], cancel func(error)) {
+	}, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
 		target, err := up.Grpc.BuildTarget()
 		if err != nil {
 			cancel(err)
@@ -85,26 +72,14 @@ func (s *Server) buildClient() error {
 		}
 
 		up.Name = target
-		cli := zrpc.MustNewClient(up.Grpc)
-		writer.Write(&upstream{
-			Upstream: up,
-			client:   cli,
-		})
-	}, func(pipe <-chan *upstream, cancel func(error)) {
-		for up := range pipe {
-			s.upstreams = append(s.upstreams, up)
+		var cli zrpc.Client
+		if s.dialer != nil {
+			cli = s.dialer(up.Grpc)
+		} else {
+			cli = zrpc.MustNewClient(up.Grpc)
 		}
-	})
-}
 
-func (s *Server) buildUpstream() error {
-	return mr.MapReduceVoid(func(source chan<- *upstream) {
-		for _, up := range s.upstreams {
-			source <- up
-		}
-	}, func(up *upstream, writer mr.Writer[rest.Route], cancel func(error)) {
-		cli := up.client
-		source, err := s.createDescriptorSource(cli, up.Upstream)
+		source, err := s.createDescriptorSource(cli, up)
 		if err != nil {
 			cancel(fmt.Errorf("%s: %w", up.Name, err))
 			return
@@ -191,13 +166,13 @@ func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.D
 }
 
 func (s *Server) ensureUpstreamNames() error {
-	for _, up := range s.c.Upstreams {
-		target, err := up.Grpc.BuildTarget()
+	for i := 0; i < len(s.upstreams); i++ {
+		target, err := s.upstreams[i].Grpc.BuildTarget()
 		if err != nil {
 			return err
 		}
 
-		up.Name = target
+		s.upstreams[i].Name = target
 	}
 
 	return nil
@@ -219,3 +194,10 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
 		s.processHeader = processHeader
 	}
 }
+
+// withDialer sets a dialer to create a gRPC client.
+func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) {
+	return func(s *Server) {
+		s.dialer = dialer
+	}
+}

+ 36 - 24
gateway/server_test.go

@@ -49,39 +49,36 @@ func TestMustNewServer(t *testing.T) {
 	c.Host = "localhost"
 	c.Port = 18881
 
-	s := MustNewServer(c)
-	s.upstreams = []*upstream{
+	s := MustNewServer(c, withDialer(func(conf zrpc.RpcClientConf) zrpc.Client {
+		return zrpc.MustNewClient(conf, zrpc.WithDialOption(grpc.WithContextDialer(dialer())))
+	}))
+	s.upstreams = []Upstream{
 		{
-			Upstream: Upstream{
-				Mappings: []RouteMapping{
-					{
-						Method:  "get",
-						Path:    "/deposit/:amount",
-						RpcPath: "mock.DepositService/Deposit",
-					},
+			Mappings: []RouteMapping{
+				{
+					Method:  "get",
+					Path:    "/deposit/:amount",
+					RpcPath: "mock.DepositService/Deposit",
 				},
 			},
-			client: zrpc.MustNewClient(
-				zrpc.RpcClientConf{
-					Endpoints: []string{"foo"},
-					Timeout:   1000,
-					Middlewares: zrpc.ClientMiddlewaresConf{
-						Trace:      true,
-						Duration:   true,
-						Prometheus: true,
-						Breaker:    true,
-						Timeout:    true,
-					},
+			Grpc: zrpc.RpcClientConf{
+				Endpoints: []string{"foo"},
+				Timeout:   1000,
+				Middlewares: zrpc.ClientMiddlewaresConf{
+					Trace:      true,
+					Duration:   true,
+					Prometheus: true,
+					Breaker:    true,
+					Timeout:    true,
 				},
-				zrpc.WithDialOption(grpc.WithContextDialer(dialer())),
-			),
+			},
 		},
 	}
 
-	assert.NoError(t, s.buildUpstream())
+	assert.NoError(t, s.build())
 	go s.Server.Start()
 
-	time.Sleep(time.Millisecond * 100)
+	time.Sleep(time.Millisecond * 200)
 
 	resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil)
 	assert.NoError(t, err)
@@ -91,3 +88,18 @@ func TestMustNewServer(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Equal(t, http.StatusNotFound, resp.StatusCode)
 }
+
+func TestServer_ensureUpstreamNames(t *testing.T) {
+	var s = Server{
+		upstreams: []Upstream{
+			{
+				Grpc: zrpc.RpcClientConf{
+					Target: "target",
+				},
+			},
+		},
+	}
+
+	assert.NoError(t, s.ensureUpstreamNames())
+	assert.Equal(t, "target", s.upstreams[0].Name)
+}