Browse Source

feat: support ctx in `Cache` (#1518)

* feature: support ctx in `Cache`

Signed-off-by: chenquan <chenquan.dev@foxmail.com>

* fix: `errors.Is` instead of `=`

Signed-off-by: chenquan <chenquan.dev@foxmail.com>
chenquan 3 năm trước cách đây
mục cha
commit
e8c307e4dc
3 tập tin đã thay đổi với 169 bổ sung36 xóa
  1. 74 8
      core/stores/cache/cache.go
  2. 30 2
      core/stores/cache/cache_test.go
  3. 65 26
      core/stores/cache/cachenode.go

+ 74 - 8
core/stores/cache/cache.go

@@ -1,6 +1,8 @@
 package cache
 
 import (
+	"context"
+	"errors"
 	"fmt"
 	"log"
 	"time"
@@ -13,13 +15,36 @@ import (
 type (
 	// Cache interface is used to define the cache implementation.
 	Cache interface {
+		// Del deletes cached values with keys.
 		Del(keys ...string) error
+		// DelCtx deletes cached values with keys.
+		DelCtx(ctx context.Context, keys ...string) error
+		// Get gets the cache with key and fills into v.
 		Get(key string, v interface{}) error
+		// GetCtx gets the cache with key and fills into v.
+		GetCtx(ctx context.Context, key string, v interface{}) error
+		// IsNotFound checks if the given error is the defined errNotFound.
 		IsNotFound(err error) bool
+		// Set sets the cache with key and v, using c.expiry.
 		Set(key string, v interface{}) error
+		// SetCtx sets the cache with key and v, using c.expiry.
+		SetCtx(ctx context.Context, key string, v interface{}) error
+		// SetWithExpire sets the cache with key and v, using given expire.
 		SetWithExpire(key string, v interface{}, expire time.Duration) error
+		// SetWithExpireCtx sets the cache with key and v, using given expire.
+		SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error
+		// Take takes the result from cache first, if not found,
+		// query from DB and set cache using c.expiry, then return the result.
 		Take(v interface{}, key string, query func(v interface{}) error) error
+		// TakeCtx takes the result from cache first, if not found,
+		// query from DB and set cache using c.expiry, then return the result.
+		TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error
+		// TakeWithExpire takes the result from cache first, if not found,
+		// query from DB and set cache using given expire, then return the result.
 		TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error
+		// TakeWithExpireCtx takes the result from cache first, if not found,
+		// query from DB and set cache using given expire, then return the result.
+		TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error
 	}
 
 	cacheCluster struct {
@@ -51,7 +76,13 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
 	}
 }
 
+// Del deletes cached values with keys.
 func (cc cacheCluster) Del(keys ...string) error {
+	return cc.DelCtx(context.Background(), keys...)
+}
+
+// DelCtx deletes cached values with keys.
+func (cc cacheCluster) DelCtx(ctx context.Context, keys ...string) error {
 	switch len(keys) {
 	case 0:
 		return nil
@@ -62,7 +93,7 @@ func (cc cacheCluster) Del(keys ...string) error {
 			return cc.errNotFound
 		}
 
-		return c.(Cache).Del(key)
+		return c.(Cache).DelCtx(ctx, key)
 	default:
 		var be errorx.BatchError
 		nodes := make(map[interface{}][]string)
@@ -76,7 +107,7 @@ func (cc cacheCluster) Del(keys ...string) error {
 			nodes[c] = append(nodes[c], key)
 		}
 		for c, ks := range nodes {
-			if err := c.(Cache).Del(ks...); err != nil {
+			if err := c.(Cache).DelCtx(ctx, ks...); err != nil {
 				be.Add(err)
 			}
 		}
@@ -85,52 +116,87 @@ func (cc cacheCluster) Del(keys ...string) error {
 	}
 }
 
+// Get gets the cache with key and fills into v.
 func (cc cacheCluster) Get(key string, v interface{}) error {
+	return cc.GetCtx(context.Background(), key, v)
+}
+
+// GetCtx gets the cache with key and fills into v.
+func (cc cacheCluster) GetCtx(ctx context.Context, key string, v interface{}) error {
 	c, ok := cc.dispatcher.Get(key)
 	if !ok {
 		return cc.errNotFound
 	}
 
-	return c.(Cache).Get(key, v)
+	return c.(Cache).GetCtx(ctx, key, v)
 }
 
+// IsNotFound checks if the given error is the defined errNotFound.
 func (cc cacheCluster) IsNotFound(err error) bool {
-	return err == cc.errNotFound
+	return errors.Is(err, cc.errNotFound)
 }
 
+// Set sets the cache with key and v, using c.expiry.
 func (cc cacheCluster) Set(key string, v interface{}) error {
+	return cc.SetCtx(context.Background(), key, v)
+}
+
+// SetCtx sets the cache with key and v, using c.expiry.
+func (cc cacheCluster) SetCtx(ctx context.Context, key string, v interface{}) error {
 	c, ok := cc.dispatcher.Get(key)
 	if !ok {
 		return cc.errNotFound
 	}
 
-	return c.(Cache).Set(key, v)
+	return c.(Cache).SetCtx(ctx, key, v)
 }
 
+// SetWithExpire sets the cache with key and v, using given expire.
 func (cc cacheCluster) SetWithExpire(key string, v interface{}, expire time.Duration) error {
+	return cc.SetWithExpireCtx(context.Background(), key, v, expire)
+}
+
+// SetWithExpireCtx sets the cache with key and v, using given expire.
+func (cc cacheCluster) SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error {
 	c, ok := cc.dispatcher.Get(key)
 	if !ok {
 		return cc.errNotFound
 	}
 
-	return c.(Cache).SetWithExpire(key, v, expire)
+	return c.(Cache).SetWithExpireCtx(ctx, key, v, expire)
 }
 
+// Take takes the result from cache first, if not found,
+// query from DB and set cache using c.expiry, then return the result.
 func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error {
+	return cc.TakeCtx(context.Background(), v, key, query)
+}
+
+// TakeCtx takes the result from cache first, if not found,
+// query from DB and set cache using c.expiry, then return the result.
+func (cc cacheCluster) TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error {
 	c, ok := cc.dispatcher.Get(key)
 	if !ok {
 		return cc.errNotFound
 	}
 
-	return c.(Cache).Take(v, key, query)
+	return c.(Cache).TakeCtx(ctx, v, key, query)
 }
 
+// TakeWithExpire takes the result from cache first, if not found,
+// query from DB and set cache using given expire, then return the result.
 func (cc cacheCluster) TakeWithExpire(v interface{}, key string,
 	query func(v interface{}, expire time.Duration) error) error {
+	return cc.TakeWithExpireCtx(context.Background(), v, key, query)
+}
+
+// TakeWithExpireCtx takes the result from cache first, if not found,
+// query from DB and set cache using given expire, then return the result.
+func (cc cacheCluster) TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
 	c, ok := cc.dispatcher.Get(key)
 	if !ok {
 		return cc.errNotFound
 	}
 
-	return c.(Cache).TakeWithExpire(v, key, query)
+	return c.(Cache).TakeWithExpireCtx(ctx, v, key, query)
 }

+ 30 - 2
core/stores/cache/cache_test.go

@@ -1,7 +1,9 @@
 package cache
 
 import (
+	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"math"
 	"strconv"
@@ -16,6 +18,8 @@ import (
 	"github.com/zeromicro/go-zero/core/syncx"
 )
 
+var _ Cache = (*mockedNode)(nil)
+
 type mockedNode struct {
 	vals        map[string][]byte
 	errNotFound error
@@ -45,7 +49,7 @@ func (mc *mockedNode) Get(key string, v interface{}) error {
 }
 
 func (mc *mockedNode) IsNotFound(err error) bool {
-	return err == mc.errNotFound
+	return errors.Is(err, mc.errNotFound)
 }
 
 func (mc *mockedNode) Set(key string, v interface{}) error {
@@ -58,7 +62,7 @@ func (mc *mockedNode) Set(key string, v interface{}) error {
 	return nil
 }
 
-func (mc *mockedNode) SetWithExpire(key string, v interface{}, expire time.Duration) error {
+func (mc *mockedNode) SetWithExpire(key string, v interface{}, _ time.Duration) error {
 	return mc.Set(key, v)
 }
 
@@ -80,6 +84,30 @@ func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v int
 	})
 }
 
+func (mc *mockedNode) DelCtx(_ context.Context, keys ...string) error {
+	return mc.Del(keys...)
+}
+
+func (mc *mockedNode) GetCtx(_ context.Context, key string, v interface{}) error {
+	return mc.Get(key, v)
+}
+
+func (mc *mockedNode) SetCtx(_ context.Context, key string, v interface{}) error {
+	return mc.Set(key, v)
+}
+
+func (mc *mockedNode) SetWithExpireCtx(_ context.Context, key string, v interface{}, expire time.Duration) error {
+	return mc.SetWithExpire(key, v, expire)
+}
+
+func (mc *mockedNode) TakeCtx(_ context.Context, v interface{}, key string, query func(v interface{}) error) error {
+	return mc.Take(v, key, query)
+}
+
+func (mc *mockedNode) TakeWithExpireCtx(_ context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
+	return mc.TakeWithExpire(v, key, query)
+}
+
 func TestCache_SetDel(t *testing.T) {
 	const total = 1000
 	r1, clean1, err := redistest.CreateRedis()

+ 65 - 26
core/stores/cache/cachenode.go

@@ -1,6 +1,7 @@
 package cache
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"math/rand"
@@ -61,20 +62,27 @@ func NewNode(rds *redis.Redis, barrier syncx.SingleFlight, st *Stat,
 
 // Del deletes cached values with keys.
 func (c cacheNode) Del(keys ...string) error {
+	return c.DelCtx(context.Background(), keys...)
+}
+
+// DelCtx deletes cached values with keys.
+func (c cacheNode) DelCtx(ctx context.Context, keys ...string) error {
 	if len(keys) == 0 {
 		return nil
 	}
 
+	logger := logx.WithContext(ctx)
+
 	if len(keys) > 1 && c.rds.Type == redis.ClusterType {
 		for _, key := range keys {
-			if _, err := c.rds.Del(key); err != nil {
-				logx.Errorf("failed to clear cache with key: %q, error: %v", key, err)
+			if _, err := c.rds.DelCtx(ctx, key); err != nil {
+				logger.Errorf("failed to clear cache with key: %q, error: %v", key, err)
 				c.asyncRetryDelCache(key)
 			}
 		}
 	} else {
-		if _, err := c.rds.Del(keys...); err != nil {
-			logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
+		if _, err := c.rds.DelCtx(ctx, keys...); err != nil {
+			logger.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
 			c.asyncRetryDelCache(keys...)
 		}
 	}
@@ -84,7 +92,12 @@ func (c cacheNode) Del(keys ...string) error {
 
 // Get gets the cache with key and fills into v.
 func (c cacheNode) Get(key string, v interface{}) error {
-	err := c.doGetCache(key, v)
+	return c.GetCtx(context.Background(), key, v)
+}
+
+// GetCtx gets the cache with key and fills into v.
+func (c cacheNode) GetCtx(ctx context.Context, key string, v interface{}) error {
+	err := c.doGetCache(ctx, key, v)
 	if err == errPlaceholder {
 		return c.errNotFound
 	}
@@ -94,22 +107,32 @@ func (c cacheNode) Get(key string, v interface{}) error {
 
 // IsNotFound checks if the given error is the defined errNotFound.
 func (c cacheNode) IsNotFound(err error) bool {
-	return err == c.errNotFound
+	return errors.Is(err, c.errNotFound)
 }
 
 // Set sets the cache with key and v, using c.expiry.
 func (c cacheNode) Set(key string, v interface{}) error {
-	return c.SetWithExpire(key, v, c.aroundDuration(c.expiry))
+	return c.SetCtx(context.Background(), key, v)
+}
+
+// SetCtx sets the cache with key and v, using c.expiry.
+func (c cacheNode) SetCtx(ctx context.Context, key string, v interface{}) error {
+	return c.SetWithExpireCtx(ctx, key, v, c.aroundDuration(c.expiry))
 }
 
 // SetWithExpire sets the cache with key and v, using given expire.
 func (c cacheNode) SetWithExpire(key string, v interface{}, expire time.Duration) error {
+	return c.SetWithExpireCtx(context.Background(), key, v, expire)
+}
+
+// SetWithExpireCtx sets the cache with key and v, using given expire.
+func (c cacheNode) SetWithExpireCtx(ctx context.Context, key string, v interface{}, expire time.Duration) error {
 	data, err := jsonx.Marshal(v)
 	if err != nil {
 		return err
 	}
 
-	return c.rds.Setex(key, string(data), int(expire.Seconds()))
+	return c.rds.SetexCtx(ctx, key, string(data), int(expire.Seconds()))
 }
 
 // String returns a string that represents the cacheNode.
@@ -120,8 +143,14 @@ func (c cacheNode) String() string {
 // Take takes the result from cache first, if not found,
 // query from DB and set cache using c.expiry, then return the result.
 func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error {
-	return c.doTake(v, key, query, func(v interface{}) error {
-		return c.Set(key, v)
+	return c.TakeCtx(context.Background(), v, key, query)
+}
+
+// TakeCtx takes the result from cache first, if not found,
+// query from DB and set cache using c.expiry, then return the result.
+func (c cacheNode) TakeCtx(ctx context.Context, v interface{}, key string, query func(v interface{}) error) error {
+	return c.doTake(ctx, v, key, query, func(v interface{}) error {
+		return c.SetCtx(ctx, key, v)
 	})
 }
 
@@ -129,11 +158,17 @@ func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) err
 // query from DB and set cache using given expire, then return the result.
 func (c cacheNode) TakeWithExpire(v interface{}, key string, query func(v interface{},
 	expire time.Duration) error) error {
+	return c.TakeWithExpireCtx(context.Background(), v, key, query)
+}
+
+// TakeWithExpireCtx takes the result from cache first, if not found,
+// query from DB and set cache using given expire, then return the result.
+func (c cacheNode) TakeWithExpireCtx(ctx context.Context, v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
 	expire := c.aroundDuration(c.expiry)
-	return c.doTake(v, key, func(v interface{}) error {
+	return c.doTake(ctx, v, key, func(v interface{}) error {
 		return query(v, expire)
 	}, func(v interface{}) error {
-		return c.SetWithExpire(key, v, expire)
+		return c.SetWithExpireCtx(ctx, key, v, expire)
 	})
 }
 
@@ -148,9 +183,9 @@ func (c cacheNode) asyncRetryDelCache(keys ...string) {
 	}, keys...)
 }
 
-func (c cacheNode) doGetCache(key string, v interface{}) error {
+func (c cacheNode) doGetCache(ctx context.Context, key string, v interface{}) error {
 	c.stat.IncrementTotal()
-	data, err := c.rds.Get(key)
+	data, err := c.rds.GetCtx(ctx, key)
 	if err != nil {
 		c.stat.IncrementMiss()
 		return err
@@ -166,13 +201,15 @@ func (c cacheNode) doGetCache(key string, v interface{}) error {
 		return errPlaceholder
 	}
 
-	return c.processCache(key, data, v)
+	return c.processCache(ctx, key, data, v)
 }
 
-func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) error,
+func (c cacheNode) doTake(ctx context.Context, v interface{}, key string, query func(v interface{}) error,
 	cacheVal func(v interface{}) error) error {
+	logger := logx.WithContext(ctx)
+
 	val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) {
-		if err := c.doGetCache(key, v); err != nil {
+		if err := c.doGetCache(ctx, key, v); err != nil {
 			if err == errPlaceholder {
 				return nil, c.errNotFound
 			} else if err != c.errNotFound {
@@ -183,8 +220,8 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
 			}
 
 			if err = query(v); err == c.errNotFound {
-				if err = c.setCacheWithNotFound(key); err != nil {
-					logx.Error(err)
+				if err = c.setCacheWithNotFound(ctx, key); err != nil {
+					logger.Error(err)
 				}
 
 				return nil, c.errNotFound
@@ -194,7 +231,7 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
 			}
 
 			if err = cacheVal(v); err != nil {
-				logx.Error(err)
+				logger.Error(err)
 			}
 		}
 
@@ -214,7 +251,9 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
 	return jsonx.Unmarshal(val.([]byte), v)
 }
 
-func (c cacheNode) processCache(key, data string, v interface{}) error {
+func (c cacheNode) processCache(ctx context.Context, key, data string, v interface{}) error {
+	logger := logx.WithContext(ctx)
+
 	err := jsonx.Unmarshal([]byte(data), v)
 	if err == nil {
 		return nil
@@ -222,10 +261,10 @@ func (c cacheNode) processCache(key, data string, v interface{}) error {
 
 	report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v",
 		c.rds.Addr, key, data, err)
-	logx.Error(report)
+	logger.Error(report)
 	stat.Report(report)
-	if _, e := c.rds.Del(key); e != nil {
-		logx.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v",
+	if _, e := c.rds.DelCtx(ctx, key); e != nil {
+		logger.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v",
 			c.rds.Addr, key, data, e)
 	}
 
@@ -233,6 +272,6 @@ func (c cacheNode) processCache(key, data string, v interface{}) error {
 	return c.errNotFound
 }
 
-func (c cacheNode) setCacheWithNotFound(key string) error {
-	return c.rds.Setex(key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds()))
+func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
+	return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds()))
 }