Browse Source

chore: refactor errors to use errors.Is (#3654)

Kevin Wan 1 year ago
parent
commit
42e0a6f90c

+ 10 - 10
core/breaker/breakers_test.go

@@ -30,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
 		assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
 			return errDummy
 		}, func(err error) bool {
-			return err == nil || err == errDummy
+			return err == nil || errors.Is(err, errDummy)
 		}))
 	}
 	verify(t, func() bool {
@@ -45,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
 		}, func(err error) bool {
 			return err == nil
 		})
-		assert.True(t, err == errDummy || err == ErrServiceUnavailable)
+		assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable))
 	}
 	verify(t, func() bool {
-		return ErrServiceUnavailable == Do("another", func() error {
+		return errors.Is(Do("another", func() error {
 			return nil
-		})
+		}), ErrServiceUnavailable)
 	})
 }
 
@@ -75,12 +75,12 @@ func TestBreakersFallback(t *testing.T) {
 		}, func(err error) error {
 			return nil
 		})
-		assert.True(t, err == nil || err == errDummy)
+		assert.True(t, err == nil || errors.Is(err, errDummy))
 	}
 	verify(t, func() bool {
-		return ErrServiceUnavailable == Do("fallback", func() error {
+		return errors.Is(Do("fallback", func() error {
 			return nil
-		})
+		}), ErrServiceUnavailable)
 	})
 }
 
@@ -94,12 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
 		}, func(err error) bool {
 			return err == nil
 		})
-		assert.True(t, err == nil || err == errDummy)
+		assert.True(t, err == nil || errors.Is(err, errDummy))
 	}
 	verify(t, func() bool {
-		return ErrServiceUnavailable == Do("acceptablefallback", func() error {
+		return errors.Is(Do("acceptablefallback", func() error {
 			return nil
-		})
+		}), ErrServiceUnavailable)
 	})
 }
 

+ 3 - 3
core/search/tree.go

@@ -69,10 +69,10 @@ func (t *Tree) Add(route string, item any) error {
 	}
 
 	err := add(t.root, route[1:], item)
-	switch err {
-	case errDupItem:
+	switch {
+	case errors.Is(err, errDupItem):
 		return duplicatedItem(route)
-	case errDupSlash:
+	case errors.Is(err, errDupSlash):
 		return duplicatedSlash(route)
 	default:
 		return err

+ 4 - 4
core/stores/cache/cachenode.go

@@ -96,7 +96,7 @@ func (c cacheNode) Get(key string, val any) error {
 // GetCtx gets the cache with key and fills into v.
 func (c cacheNode) GetCtx(ctx context.Context, key string, val any) error {
 	err := c.doGetCache(ctx, key, val)
-	if err == errPlaceholder {
+	if errors.Is(err, errPlaceholder) {
 		return c.errNotFound
 	}
 
@@ -210,16 +210,16 @@ func (c cacheNode) doTake(ctx context.Context, v any, key string,
 	logger := logx.WithContext(ctx)
 	val, fresh, err := c.barrier.DoEx(key, func() (any, error) {
 		if err := c.doGetCache(ctx, key, v); err != nil {
-			if err == errPlaceholder {
+			if errors.Is(err, errPlaceholder) {
 				return nil, c.errNotFound
-			} else if err != c.errNotFound {
+			} else if !errors.Is(err, c.errNotFound) {
 				// why we just return the error instead of query from db,
 				// because we don't allow the disaster pass to the dbs.
 				// fail fast, in case we bring down the dbs.
 				return nil, err
 			}
 
-			if err = query(v); err == c.errNotFound {
+			if err = query(v); errors.Is(err, c.errNotFound) {
 				if err = c.setCacheWithNotFound(ctx, key); err != nil {
 					logger.Error(err)
 				}

+ 15 - 6
core/stores/mon/collection.go

@@ -3,6 +3,7 @@ package mon
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"time"
 
 	"github.com/zeromicro/go-zero/core/breaker"
@@ -562,11 +563,19 @@ func (p keepablePromise) keep(err error) error {
 }
 
 func acceptable(err error) bool {
-	return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue ||
-		err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice ||
+	return err == nil ||
+		errors.Is(err, mongo.ErrNoDocuments) ||
+		errors.Is(err, mongo.ErrNilValue) ||
+		errors.Is(err, mongo.ErrNilDocument) ||
+		errors.Is(err, mongo.ErrNilCursor) ||
+		errors.Is(err, mongo.ErrEmptySlice) ||
 		// session errors
-		err == session.ErrSessionEnded || err == session.ErrNoTransactStarted ||
-		err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit ||
-		err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort ||
-		err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction
+		errors.Is(err, session.ErrSessionEnded) ||
+		errors.Is(err, session.ErrNoTransactStarted) ||
+		errors.Is(err, session.ErrTransactInProgress) ||
+		errors.Is(err, session.ErrAbortAfterCommit) ||
+		errors.Is(err, session.ErrAbortTwice) ||
+		errors.Is(err, session.ErrCommitAfterAbort) ||
+		errors.Is(err, session.ErrUnackWCUnsupported) ||
+		errors.Is(err, session.ErrSnapshotTransaction)
 }

+ 3 - 2
core/stores/mon/trace.go

@@ -2,6 +2,7 @@ package mon
 
 import (
 	"context"
+	"errors"
 
 	"github.com/zeromicro/go-zero/core/trace"
 	"go.mongodb.org/mongo-driver/mongo"
@@ -23,8 +24,8 @@ func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span
 func endSpan(span oteltrace.Span, err error) {
 	defer span.End()
 
-	if err == nil || err == mongo.ErrNoDocuments ||
-		err == mongo.ErrNilValue || err == mongo.ErrNilDocument {
+	if err == nil || errors.Is(err, mongo.ErrNoDocuments) ||
+		errors.Is(err, mongo.ErrNilValue) || errors.Is(err, mongo.ErrNilDocument) {
 		span.SetStatus(codes.Ok, "")
 		return
 	}

+ 1 - 1
core/stores/redis/redis.go

@@ -2849,7 +2849,7 @@ func withHook(hook red.Hook) Option {
 }
 
 func acceptable(err error) bool {
-	return err == nil || err == red.Nil || err == context.Canceled
+	return err == nil || err == red.Nil || errors.Is(err, context.Canceled)
 }
 
 func getRedis(r *Redis) (RedisNode, error) {

+ 7 - 2
core/stores/sqlx/mysql.go

@@ -1,6 +1,10 @@
 package sqlx
 
-import "github.com/go-sql-driver/mysql"
+import (
+	"errors"
+
+	"github.com/go-sql-driver/mysql"
+)
 
 const (
 	mysqlDriverName           = "mysql"
@@ -18,7 +22,8 @@ func mysqlAcceptable(err error) bool {
 		return true
 	}
 
-	myerr, ok := err.(*mysql.MySQLError)
+	var myerr *mysql.MySQLError
+	ok := errors.As(err, &myerr)
 	if !ok {
 		return false
 	}

+ 1 - 1
core/stores/sqlx/mysql_test.go

@@ -28,7 +28,7 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
 
 	var found bool
 	for i := 0; i < 100; i++ {
-		if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable {
+		if errors.Is(tryOnDuplicateEntryError(t, nil), breaker.ErrServiceUnavailable) {
 			found = true
 		}
 	}

+ 10 - 7
core/stores/sqlx/sqlconn.go

@@ -3,6 +3,7 @@ package sqlx
 import (
 	"context"
 	"database/sql"
+	"errors"
 
 	"github.com/zeromicro/go-zero/core/breaker"
 	"github.com/zeromicro/go-zero/core/logx"
@@ -157,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
 		result, err = exec(ctx, conn, q, args...)
 		return err
 	}, db.acceptable)
-	if err == breaker.ErrServiceUnavailable {
+	if errors.Is(err, breaker.ErrServiceUnavailable) {
 		metricReqErr.Inc("Exec", "breaker")
 	}
 
@@ -193,7 +194,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
 		}
 		return nil
 	}, db.acceptable)
-	if err == breaker.ErrServiceUnavailable {
+	if errors.Is(err, breaker.ErrServiceUnavailable) {
 		metricReqErr.Inc("Prepare", "breaker")
 	}
 
@@ -283,7 +284,7 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
 	err = db.brk.DoWithAcceptable(func() error {
 		return transact(ctx, db, db.beginTx, fn)
 	}, db.acceptable)
-	if err == breaker.ErrServiceUnavailable {
+	if errors.Is(err, breaker.ErrServiceUnavailable) {
 		metricReqErr.Inc("Transact", "breaker")
 	}
 
@@ -291,11 +292,13 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
 }
 
 func (db *commonSqlConn) acceptable(err error) bool {
-	if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
+	if err == nil || errors.Is(err, sql.ErrNoRows) || errors.Is(err, sql.ErrTxDone) ||
+		errors.Is(err, context.Canceled) {
 		return true
 	}
 
-	if _, ok := err.(acceptableError); ok {
+	var e acceptableError
+	if errors.As(err, &e) {
 		return true
 	}
 
@@ -321,9 +324,9 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
 			return qerr
 		}, q, args...)
 	}, func(err error) bool {
-		return qerr == err || db.acceptable(err)
+		return errors.Is(err, qerr) || db.acceptable(err)
 	})
-	if err == breaker.ErrServiceUnavailable {
+	if errors.Is(err, breaker.ErrServiceUnavailable) {
 		metricReqErr.Inc("queryRows", "breaker")
 	}
 

+ 1 - 1
core/stores/sqlx/utils.go

@@ -143,7 +143,7 @@ func logInstanceError(ctx context.Context, datasource string, err error) {
 }
 
 func logSqlError(ctx context.Context, stmt string, err error) {
-	if err != nil && err != ErrNotFound {
+	if err != nil && !errors.Is(err, ErrNotFound) {
 		logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
 	}
 }

+ 1 - 1
core/syncx/lockedcalls_test.go

@@ -27,7 +27,7 @@ func TestLockedCallDoErr(t *testing.T) {
 	v, err := g.Do("key", func() (any, error) {
 		return nil, someErr
 	})
-	if err != someErr {
+	if !errors.Is(err, someErr) {
 		t.Errorf("Do error = %v; want someErr", err)
 	}
 	if v != nil {

+ 1 - 1
core/syncx/singleflight_test.go

@@ -28,7 +28,7 @@ func TestExclusiveCallDoErr(t *testing.T) {
 	v, err := g.Do("key", func() (any, error) {
 		return nil, someErr
 	})
-	if err != someErr {
+	if !errors.Is(err, someErr) {
 		t.Errorf("Do error = %v; want someErr", err)
 	}
 	if v != nil {

+ 6 - 5
rest/httpx/responses.go

@@ -3,6 +3,7 @@ package httpx
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"net/http"
 	"sync"
@@ -141,10 +142,10 @@ func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, a
 		return
 	}
 
-	e, ok := body.(error)
-	if ok {
-		http.Error(w, e.Error(), code)
-	} else {
+	switch v := body.(type) {
+	case error:
+		http.Error(w, v.Error(), code)
+	default:
 		writeJson(w, code, body)
 	}
 }
@@ -162,7 +163,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
 	if n, err := w.Write(bs); err != nil {
 		// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
 		// so it's ignored here.
-		if err != http.ErrHandlerTimeout {
+		if !errors.Is(err, http.ErrHandlerTimeout) {
 			return fmt.Errorf("write response failed, error: %w", err)
 		}
 	} else if n < len(bs) {

+ 2 - 1
rest/internal/starter.go

@@ -2,6 +2,7 @@ package internal
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"net/http"
 
@@ -49,7 +50,7 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve
 		}
 	})
 	defer func() {
-		if err == http.ErrServerClosed {
+		if errors.Is(err, http.ErrServerClosed) {
 			waitForCalled()
 		}
 	}()

+ 2 - 1
rest/server.go

@@ -2,6 +2,7 @@ package rest
 
 import (
 	"crypto/tls"
+	"errors"
 	"net/http"
 	"path"
 	"time"
@@ -307,7 +308,7 @@ func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
 
 func handleError(err error) {
 	// ErrServerClosed means the server is closed manually
-	if err == nil || err == http.ErrServerClosed {
+	if err == nil || errors.Is(err, http.ErrServerClosed) {
 		return
 	}
 

+ 4 - 2
tools/goctl/model/cmd.go

@@ -56,7 +56,8 @@ func init() {
 	pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home")
 	pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote")
 	pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
-	pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
+	pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
+		"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
 
 	mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
 	mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
@@ -68,7 +69,8 @@ func init() {
 	mongoCmdFlags.StringVar(&mongo.VarStringBranch, "branch")
 
 	mysqlCmd.PersistentFlags().BoolVar(&command.VarBoolStrict, "strict")
-	mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns, "ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
+	mysqlCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
+		"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
 
 	mysqlCmd.AddCommand(datasourceCmd, ddlCmd)
 	pgCmd.AddCommand(pgDatasourceCmd)

+ 2 - 1
zrpc/internal/codes/accept.go

@@ -8,7 +8,8 @@ import (
 // Acceptable checks if given error is acceptable.
 func Acceptable(err error) bool {
 	switch status.Code(err) {
-	case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss, codes.Unimplemented:
+	case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss,
+		codes.Unimplemented, codes.ResourceExhausted:
 		return false
 	default:
 		return true

+ 6 - 0
zrpc/internal/serverinterceptors/breakerinterceptor.go

@@ -2,10 +2,13 @@ package serverinterceptors
 
 import (
 	"context"
+	"errors"
 
 	"github.com/zeromicro/go-zero/core/breaker"
 	"github.com/zeromicro/go-zero/zrpc/internal/codes"
 	"google.golang.org/grpc"
+	gcodes "google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 // StreamBreakerInterceptor is an interceptor that acts as a circuit breaker.
@@ -26,6 +29,9 @@ func UnaryBreakerInterceptor(ctx context.Context, req any, info *grpc.UnaryServe
 		resp, err = handler(ctx, req)
 		return err
 	}, codes.Acceptable)
+	if errors.Is(err, breaker.ErrServiceUnavailable) {
+		err = status.Error(gcodes.Unavailable, err.Error())
+	}
 
 	return resp, err
 }

+ 5 - 1
zrpc/internal/serverinterceptors/sheddinginterceptor.go

@@ -2,11 +2,14 @@ package serverinterceptors
 
 import (
 	"context"
+	"errors"
 	"sync"
 
 	"github.com/zeromicro/go-zero/core/load"
 	"github.com/zeromicro/go-zero/core/stat"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 const serviceType = "rpc"
@@ -28,11 +31,12 @@ func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.
 		if err != nil {
 			metrics.AddDrop()
 			sheddingStat.IncrementDrop()
+			err = status.Error(codes.ResourceExhausted, err.Error())
 			return
 		}
 
 		defer func() {
-			if err == context.DeadlineExceeded {
+			if errors.Is(err, context.DeadlineExceeded) {
 				promise.Fail()
 			} else {
 				sheddingStat.IncrementPass()

+ 3 - 1
zrpc/internal/serverinterceptors/sheddinginterceptor_test.go

@@ -8,6 +8,8 @@ import (
 	"github.com/zeromicro/go-zero/core/load"
 	"github.com/zeromicro/go-zero/core/stat"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 )
 
 func TestUnarySheddingInterceptor(t *testing.T) {
@@ -33,7 +35,7 @@ func TestUnarySheddingInterceptor(t *testing.T) {
 			name:      "reject",
 			allow:     false,
 			handleErr: nil,
-			expect:    load.ErrServiceOverloaded,
+			expect:    status.Error(codes.ResourceExhausted, load.ErrServiceOverloaded.Error()),
 		},
 	}
 

+ 3 - 2
zrpc/internal/serverinterceptors/timeoutinterceptor.go

@@ -2,6 +2,7 @@ package serverinterceptors
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"runtime/debug"
 	"strings"
@@ -49,9 +50,9 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
 			return resp, err
 		case <-ctx.Done():
 			err := ctx.Err()
-			if err == context.Canceled {
+			if errors.Is(err, context.Canceled) {
 				err = status.Error(codes.Canceled, err.Error())
-			} else if err == context.DeadlineExceeded {
+			} else if errors.Is(err, context.DeadlineExceeded) {
 				err = status.Error(codes.DeadlineExceeded, err.Error())
 			}
 			return nil, err