1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- package rpc
- import (
- "context"
- "google.golang.org/grpc"
- )
- func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
- return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...))
- }
- func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
- return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...))
- }
- func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
- switch len(interceptors) {
- case 0:
- return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
- streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
- return streamer(ctx, desc, cc, method, opts...)
- }
- case 1:
- return interceptors[0]
- default:
- last := len(interceptors) - 1
- return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
- method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
- var chainStreamer grpc.Streamer
- var current int
- chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn,
- curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) {
- if current == last {
- return streamer(curCtx, curDesc, curCc, curMethod, curOpts...)
- }
- current++
- clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...)
- current--
- return clientStream, err
- }
- return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...)
- }
- }
- }
- func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
- switch len(interceptors) {
- case 0:
- return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
- invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
- return invoker(ctx, method, req, reply, cc, opts...)
- }
- case 1:
- return interceptors[0]
- default:
- last := len(interceptors) - 1
- return func(ctx context.Context, method string, req, reply interface{},
- cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
- var chainInvoker grpc.UnaryInvoker
- var current int
- chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{},
- curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error {
- if current == last {
- return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...)
- }
- current++
- err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...)
- current--
- return err
- }
- return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...)
- }
- }
- }
|