1
0

timeoutinterceptor.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package serverinterceptors
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "runtime/debug"
  7. "strings"
  8. "sync"
  9. "time"
  10. "google.golang.org/grpc"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/status"
  13. )
  14. type (
  15. // MethodTimeoutConf defines specified timeout for gRPC method.
  16. MethodTimeoutConf struct {
  17. FullMethod string
  18. Timeout time.Duration
  19. }
  20. methodTimeouts map[string]time.Duration
  21. )
  22. // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
  23. func UnaryTimeoutInterceptor(timeout time.Duration,
  24. methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
  25. timeouts := buildMethodTimeouts(methodTimeouts)
  26. return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
  27. handler grpc.UnaryHandler) (any, error) {
  28. t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
  29. ctx, cancel := context.WithTimeout(ctx, t)
  30. defer cancel()
  31. var resp any
  32. var err error
  33. var lock sync.Mutex
  34. done := make(chan struct{})
  35. // create channel with buffer size 1 to avoid goroutine leak
  36. panicChan := make(chan any, 1)
  37. go func() {
  38. defer func() {
  39. if p := recover(); p != nil {
  40. // attach call stack to avoid missing in different goroutine
  41. panicChan <- fmt.Sprintf("%+v\n\n%s", p, strings.TrimSpace(string(debug.Stack())))
  42. }
  43. }()
  44. lock.Lock()
  45. defer lock.Unlock()
  46. resp, err = handler(ctx, req)
  47. close(done)
  48. }()
  49. select {
  50. case p := <-panicChan:
  51. panic(p)
  52. case <-done:
  53. lock.Lock()
  54. defer lock.Unlock()
  55. return resp, err
  56. case <-ctx.Done():
  57. err := ctx.Err()
  58. if errors.Is(err, context.Canceled) {
  59. err = status.Error(codes.Canceled, err.Error())
  60. } else if errors.Is(err, context.DeadlineExceeded) {
  61. err = status.Error(codes.DeadlineExceeded, err.Error())
  62. }
  63. return nil, err
  64. }
  65. }
  66. }
  67. func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
  68. mt := make(methodTimeouts, len(timeouts))
  69. for _, st := range timeouts {
  70. if st.FullMethod != "" {
  71. mt[st.FullMethod] = st.Timeout
  72. }
  73. }
  74. return mt
  75. }
  76. func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
  77. defaultTimeout time.Duration) time.Duration {
  78. if v, ok := timeouts[method]; ok {
  79. return v
  80. }
  81. return defaultTimeout
  82. }