// cachedsql_test.go
  1. package sqlc
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log"
  10. "os"
  11. "runtime"
  12. "sync"
  13. "sync/atomic"
  14. "testing"
  15. "time"
  16. "github.com/DATA-DOG/go-sqlmock"
  17. "github.com/alicebob/miniredis/v2"
  18. "github.com/stretchr/testify/assert"
  19. "github.com/wuntsong-org/go-zero-plus/core/fx"
  20. "github.com/wuntsong-org/go-zero-plus/core/logx"
  21. "github.com/wuntsong-org/go-zero-plus/core/stat"
  22. "github.com/wuntsong-org/go-zero-plus/core/stores/cache"
  23. "github.com/wuntsong-org/go-zero-plus/core/stores/dbtest"
  24. "github.com/wuntsong-org/go-zero-plus/core/stores/redis"
  25. "github.com/wuntsong-org/go-zero-plus/core/stores/redis/redistest"
  26. "github.com/wuntsong-org/go-zero-plus/core/stores/sqlx"
  27. "github.com/wuntsong-org/go-zero-plus/core/syncx"
  28. )
  29. func init() {
  30. logx.Disable()
  31. stat.SetReporter(nil)
  32. }
  33. func TestCachedConn_GetCache(t *testing.T) {
  34. resetStats()
  35. r := redistest.CreateRedis(t)
  36. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
  37. var value string
  38. err := c.GetCache("any", &value)
  39. assert.Equal(t, ErrNotFound, err)
  40. _ = r.Set("any", `"value"`)
  41. err = c.GetCache("any", &value)
  42. assert.Nil(t, err)
  43. assert.Equal(t, "value", value)
  44. }
  45. func TestStat(t *testing.T) {
  46. resetStats()
  47. r := redistest.CreateRedis(t)
  48. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
  49. for i := 0; i < 10; i++ {
  50. var str string
  51. err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
  52. *v.(*string) = "zero"
  53. return nil
  54. })
  55. if err != nil {
  56. t.Error(err)
  57. }
  58. }
  59. assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
  60. assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
  61. }
  62. func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
  63. resetStats()
  64. r := redistest.CreateRedis(t)
  65. c := NewConn(dummySqlConn{}, cache.CacheConf{
  66. {
  67. RedisConf: redis.RedisConf{
  68. Host: r.Addr,
  69. Type: redis.NodeType,
  70. },
  71. Weight: 100,
  72. },
  73. }, cache.WithExpiry(time.Second*10))
  74. var str string
  75. err := c.QueryRowIndex(&str, "index", func(s any) string {
  76. return fmt.Sprintf("%s/1234", s)
  77. }, func(conn sqlx.SqlConn, v any) (any, error) {
  78. *v.(*string) = "zero"
  79. return "primary", errors.New("foo")
  80. }, func(conn sqlx.SqlConn, v, pri any) error {
  81. assert.Equal(t, "primary", pri)
  82. *v.(*string) = "xin"
  83. return nil
  84. })
  85. assert.NotNil(t, err)
  86. err = c.QueryRowIndex(&str, "index", func(s any) string {
  87. return fmt.Sprintf("%s/1234", s)
  88. }, func(conn sqlx.SqlConn, v any) (any, error) {
  89. *v.(*string) = "zero"
  90. return "primary", nil
  91. }, func(conn sqlx.SqlConn, v, pri any) error {
  92. assert.Equal(t, "primary", pri)
  93. *v.(*string) = "xin"
  94. return nil
  95. })
  96. assert.Nil(t, err)
  97. assert.Equal(t, "zero", str)
  98. val, err := r.Get("index")
  99. assert.Nil(t, err)
  100. assert.Equal(t, `"primary"`, val)
  101. val, err = r.Get("primary/1234")
  102. assert.Nil(t, err)
  103. assert.Equal(t, `"zero"`, val)
  104. }
  105. func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
  106. resetStats()
  107. r := redistest.CreateRedis(t)
  108. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
  109. cache.WithNotFoundExpiry(time.Second))
  110. var str string
  111. r.Set("index", `"primary"`)
  112. err := c.QueryRowIndex(&str, "index", func(s any) string {
  113. return fmt.Sprintf("%s/1234", s)
  114. }, func(conn sqlx.SqlConn, v any) (any, error) {
  115. assert.Fail(t, "should not go here")
  116. return "primary", nil
  117. }, func(conn sqlx.SqlConn, v, primary any) error {
  118. *v.(*string) = "xin"
  119. assert.Equal(t, "primary", primary)
  120. return nil
  121. })
  122. assert.Nil(t, err)
  123. assert.Equal(t, "xin", str)
  124. val, err := r.Get("index")
  125. assert.Nil(t, err)
  126. assert.Equal(t, `"primary"`, val)
  127. val, err = r.Get("primary/1234")
  128. assert.Nil(t, err)
  129. assert.Equal(t, `"xin"`, val)
  130. }
  131. func TestCachedConn_QueryRowIndex_HasCache_IntPrimary(t *testing.T) {
  132. const (
  133. primaryInt8 int8 = 100
  134. primaryInt16 int16 = 10000
  135. primaryInt32 int32 = 10000000
  136. primaryInt64 int64 = 10000000
  137. primaryUint8 uint8 = 100
  138. primaryUint16 uint16 = 10000
  139. primaryUint32 uint32 = 10000000
  140. primaryUint64 uint64 = 10000000
  141. )
  142. tests := []struct {
  143. name string
  144. primary any
  145. primaryCache string
  146. }{
  147. {
  148. name: "int8 primary",
  149. primary: primaryInt8,
  150. primaryCache: fmt.Sprint(primaryInt8),
  151. },
  152. {
  153. name: "int16 primary",
  154. primary: primaryInt16,
  155. primaryCache: fmt.Sprint(primaryInt16),
  156. },
  157. {
  158. name: "int32 primary",
  159. primary: primaryInt32,
  160. primaryCache: fmt.Sprint(primaryInt32),
  161. },
  162. {
  163. name: "int64 primary",
  164. primary: primaryInt64,
  165. primaryCache: fmt.Sprint(primaryInt64),
  166. },
  167. {
  168. name: "uint8 primary",
  169. primary: primaryUint8,
  170. primaryCache: fmt.Sprint(primaryUint8),
  171. },
  172. {
  173. name: "uint16 primary",
  174. primary: primaryUint16,
  175. primaryCache: fmt.Sprint(primaryUint16),
  176. },
  177. {
  178. name: "uint32 primary",
  179. primary: primaryUint32,
  180. primaryCache: fmt.Sprint(primaryUint32),
  181. },
  182. {
  183. name: "uint64 primary",
  184. primary: primaryUint64,
  185. primaryCache: fmt.Sprint(primaryUint64),
  186. },
  187. }
  188. for _, test := range tests {
  189. t.Run(test.name, func(t *testing.T) {
  190. resetStats()
  191. r := redistest.CreateRedis(t)
  192. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
  193. cache.WithNotFoundExpiry(time.Second))
  194. var str string
  195. r.Set("index", test.primaryCache)
  196. err := c.QueryRowIndex(&str, "index", func(s any) string {
  197. return fmt.Sprintf("%v/1234", s)
  198. }, func(conn sqlx.SqlConn, v any) (any, error) {
  199. assert.Fail(t, "should not go here")
  200. return test.primary, nil
  201. }, func(conn sqlx.SqlConn, v, primary any) error {
  202. *v.(*string) = "xin"
  203. assert.Equal(t, primary, primary)
  204. return nil
  205. })
  206. assert.Nil(t, err)
  207. assert.Equal(t, "xin", str)
  208. val, err := r.Get("index")
  209. assert.Nil(t, err)
  210. assert.Equal(t, test.primaryCache, val)
  211. val, err = r.Get(test.primaryCache + "/1234")
  212. assert.Nil(t, err)
  213. assert.Equal(t, `"xin"`, val)
  214. })
  215. }
  216. }
  217. func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
  218. caches := map[string]string{
  219. "index": "primary",
  220. "primary/1234": "xin",
  221. }
  222. for k, v := range caches {
  223. t.Run(k+"/"+v, func(t *testing.T) {
  224. resetStats()
  225. r := redistest.CreateRedis(t)
  226. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
  227. cache.WithNotFoundExpiry(time.Second))
  228. var str string
  229. r.Set(k, v)
  230. err := c.QueryRowIndex(&str, "index", func(s any) string {
  231. return fmt.Sprintf("%s/1234", s)
  232. }, func(conn sqlx.SqlConn, v any) (any, error) {
  233. *v.(*string) = "xin"
  234. return "primary", nil
  235. }, func(conn sqlx.SqlConn, v, primary any) error {
  236. *v.(*string) = "xin"
  237. assert.Equal(t, "primary", primary)
  238. return nil
  239. })
  240. assert.Nil(t, err)
  241. assert.Equal(t, "xin", str)
  242. val, err := r.Get("index")
  243. assert.Nil(t, err)
  244. assert.Equal(t, `"primary"`, val)
  245. val, err = r.Get("primary/1234")
  246. assert.Nil(t, err)
  247. assert.Equal(t, `"xin"`, val)
  248. })
  249. }
  250. }
  251. func TestStatCacheFails(t *testing.T) {
  252. resetStats()
  253. log.SetOutput(io.Discard)
  254. defer log.SetOutput(os.Stdout)
  255. r := redis.New("localhost:59999")
  256. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
  257. for i := 0; i < 20; i++ {
  258. var str string
  259. err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
  260. return errors.New("db failed")
  261. })
  262. assert.NotNil(t, err)
  263. }
  264. assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
  265. assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
  266. assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
  267. assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
  268. }
  269. func TestStatDbFails(t *testing.T) {
  270. resetStats()
  271. r := redistest.CreateRedis(t)
  272. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
  273. for i := 0; i < 20; i++ {
  274. var str string
  275. err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
  276. return errors.New("db failed")
  277. })
  278. assert.NotNil(t, err)
  279. }
  280. assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
  281. assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
  282. assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
  283. }
  284. func TestStatFromMemory(t *testing.T) {
  285. resetStats()
  286. r := redistest.CreateRedis(t)
  287. c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
  288. var all sync.WaitGroup
  289. var wait sync.WaitGroup
  290. all.Add(10)
  291. wait.Add(4)
  292. go func() {
  293. var str string
  294. err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
  295. *v.(*string) = "zero"
  296. return nil
  297. })
  298. if err != nil {
  299. t.Error(err)
  300. }
  301. wait.Wait()
  302. runtime.Gosched()
  303. all.Done()
  304. }()
  305. for i := 0; i < 4; i++ {
  306. go func() {
  307. var str string
  308. wait.Done()
  309. err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
  310. *v.(*string) = "zero"
  311. return nil
  312. })
  313. if err != nil {
  314. t.Error(err)
  315. }
  316. all.Done()
  317. }()
  318. }
  319. for i := 0; i < 5; i++ {
  320. go func() {
  321. var str string
  322. err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v any) error {
  323. *v.(*string) = "zero"
  324. return nil
  325. })
  326. if err != nil {
  327. t.Error(err)
  328. }
  329. all.Done()
  330. }()
  331. }
  332. all.Wait()
  333. assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
  334. assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
  335. }
  336. func TestCachedConn_DelCache(t *testing.T) {
  337. r := redistest.CreateRedis(t)
  338. const (
  339. key = "user"
  340. value = "any"
  341. )
  342. assert.NoError(t, r.Set(key, value))
  343. c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
  344. err := c.DelCache(key)
  345. assert.Nil(t, err)
  346. val, err := r.Get(key)
  347. assert.Nil(t, err)
  348. assert.Empty(t, val)
  349. }
  350. func TestCachedConnQueryRow(t *testing.T) {
  351. r := redistest.CreateRedis(t)
  352. const (
  353. key = "user"
  354. value = "any"
  355. )
  356. var conn trackedConn
  357. var user string
  358. var ran bool
  359. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
  360. err := c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
  361. ran = true
  362. user = value
  363. return nil
  364. })
  365. assert.Nil(t, err)
  366. actualValue, err := r.Get(key)
  367. assert.Nil(t, err)
  368. var actual string
  369. assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
  370. assert.Equal(t, value, actual)
  371. assert.Equal(t, value, user)
  372. assert.True(t, ran)
  373. }
  374. func TestCachedConnQueryRowFromCache(t *testing.T) {
  375. r := redistest.CreateRedis(t)
  376. const (
  377. key = "user"
  378. value = "any"
  379. )
  380. var conn trackedConn
  381. var user string
  382. var ran bool
  383. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
  384. assert.Nil(t, c.SetCache(key, value))
  385. err := c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
  386. ran = true
  387. user = value
  388. return nil
  389. })
  390. assert.Nil(t, err)
  391. actualValue, err := r.Get(key)
  392. assert.Nil(t, err)
  393. var actual string
  394. assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
  395. assert.Equal(t, value, actual)
  396. assert.Equal(t, value, user)
  397. assert.False(t, ran)
  398. }
  399. func TestQueryRowNotFound(t *testing.T) {
  400. r := redistest.CreateRedis(t)
  401. const key = "user"
  402. var conn trackedConn
  403. var user string
  404. var ran int
  405. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
  406. for i := 0; i < 20; i++ {
  407. err := c.QueryRow(&user, key, func(conn sqlx.SqlConn, v any) error {
  408. ran++
  409. return sql.ErrNoRows
  410. })
  411. assert.Exactly(t, sqlx.ErrNotFound, err)
  412. }
  413. assert.Equal(t, 1, ran)
  414. }
  415. func TestCachedConnExec(t *testing.T) {
  416. r := redistest.CreateRedis(t)
  417. var conn trackedConn
  418. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
  419. _, err := c.ExecNoCache("delete from user_table where id='kevin'")
  420. assert.Nil(t, err)
  421. assert.True(t, conn.execValue)
  422. }
  423. func TestCachedConnExecDropCache(t *testing.T) {
  424. t.Run("drop cache", func(t *testing.T) {
  425. r, err := miniredis.Run()
  426. assert.Nil(t, err)
  427. defer fx.DoWithTimeout(func() error {
  428. r.Close()
  429. return nil
  430. }, time.Second)
  431. const (
  432. key = "user"
  433. value = "any"
  434. )
  435. var conn trackedConn
  436. c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
  437. assert.Nil(t, c.SetCache(key, value))
  438. _, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
  439. return conn.Exec("delete from user_table where id='kevin'")
  440. }, key)
  441. assert.Nil(t, err)
  442. assert.True(t, conn.execValue)
  443. _, err = r.Get(key)
  444. assert.Exactly(t, miniredis.ErrKeyNotFound, err)
  445. _, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
  446. return nil, errors.New("foo")
  447. }, key)
  448. assert.NotNil(t, err)
  449. })
  450. }
  451. func TestCachedConn_SetCacheWithExpire(t *testing.T) {
  452. r, err := miniredis.Run()
  453. assert.Nil(t, err)
  454. defer fx.DoWithTimeout(func() error {
  455. r.Close()
  456. return nil
  457. }, time.Second)
  458. const (
  459. key = "user"
  460. value = "any"
  461. )
  462. var conn trackedConn
  463. c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
  464. assert.Nil(t, c.SetCacheWithExpire(key, value, time.Minute))
  465. val, err := r.Get(key)
  466. if assert.NoError(t, err) {
  467. ttl := r.TTL(key)
  468. assert.True(t, ttl > 0 && ttl <= time.Minute)
  469. assert.Equal(t, fmt.Sprintf("%q", value), val)
  470. }
  471. }
  472. func TestCachedConnExecDropCacheFailed(t *testing.T) {
  473. const key = "user"
  474. var conn trackedConn
  475. r := redis.New("anyredis:8888")
  476. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
  477. _, err := c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
  478. return conn.Exec("delete from user_table where id='kevin'")
  479. }, key)
  480. // async background clean, retry logic
  481. assert.Nil(t, err)
  482. }
  483. func TestCachedConnQueryRows(t *testing.T) {
  484. r := redistest.CreateRedis(t)
  485. var conn trackedConn
  486. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
  487. var users []string
  488. err := c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'")
  489. assert.Nil(t, err)
  490. assert.True(t, conn.queryRowsValue)
  491. }
  492. func TestCachedConnTransact(t *testing.T) {
  493. r := redistest.CreateRedis(t)
  494. var conn trackedConn
  495. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
  496. err := c.Transact(func(session sqlx.Session) error {
  497. return nil
  498. })
  499. assert.Nil(t, err)
  500. assert.True(t, conn.transactValue)
  501. }
  502. func TestQueryRowNoCache(t *testing.T) {
  503. r := redistest.CreateRedis(t)
  504. const (
  505. key = "user"
  506. value = "any"
  507. )
  508. var user string
  509. var ran bool
  510. conn := dummySqlConn{queryRow: func(v any, q string, args ...any) error {
  511. user = value
  512. ran = true
  513. return nil
  514. }}
  515. c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
  516. err := c.QueryRowNoCache(&user, key)
  517. assert.Nil(t, err)
  518. assert.Equal(t, value, user)
  519. assert.True(t, ran)
  520. }
  521. func TestNewConnWithCache(t *testing.T) {
  522. r := redistest.CreateRedis(t)
  523. var conn trackedConn
  524. c := NewConnWithCache(&conn, cache.NewNode(r, singleFlights, stats, sql.ErrNoRows))
  525. _, err := c.ExecNoCache("delete from user_table where id='kevin'")
  526. assert.Nil(t, err)
  527. assert.True(t, conn.execValue)
  528. }
  529. func TestCachedConn_WithSession(t *testing.T) {
  530. dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
  531. mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
  532. r := redistest.CreateRedis(t)
  533. conn := CachedConn{
  534. cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
  535. }
  536. conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
  537. res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
  538. return conn.Exec("any")
  539. }, "foo")
  540. assert.NoError(t, err)
  541. last, err := res.LastInsertId()
  542. assert.NoError(t, err)
  543. assert.Equal(t, int64(2), last)
  544. affected, err := res.RowsAffected()
  545. assert.NoError(t, err)
  546. assert.Equal(t, int64(3), affected)
  547. })
  548. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  549. mock.ExpectBegin()
  550. mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
  551. mock.ExpectCommit()
  552. r := redistest.CreateRedis(t)
  553. conn := CachedConn{
  554. db: sqlx.NewSqlConnFromDB(db),
  555. cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
  556. }
  557. assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
  558. conn = conn.WithSession(session)
  559. res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
  560. return conn.Exec("any")
  561. }, "foo")
  562. assert.NoError(t, err)
  563. last, err := res.LastInsertId()
  564. assert.NoError(t, err)
  565. assert.Equal(t, int64(2), last)
  566. affected, err := res.RowsAffected()
  567. assert.NoError(t, err)
  568. assert.Equal(t, int64(3), affected)
  569. return nil
  570. }))
  571. })
  572. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  573. mock.ExpectBegin()
  574. mock.ExpectExec("any").WillReturnError(errors.New("foo"))
  575. mock.ExpectRollback()
  576. r := redistest.CreateRedis(t)
  577. conn := CachedConn{
  578. db: sqlx.NewSqlConnFromDB(db),
  579. cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
  580. }
  581. assert.Error(t, conn.Transact(func(session sqlx.Session) error {
  582. conn = conn.WithSession(session)
  583. _, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
  584. return conn.Exec("any")
  585. }, "bar")
  586. return err
  587. }))
  588. })
  589. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  590. mock.ExpectBegin()
  591. mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
  592. mock.ExpectCommit()
  593. r := redistest.CreateRedis(t)
  594. conn := CachedConn{
  595. db: sqlx.NewSqlConnFromDB(db),
  596. cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
  597. }
  598. assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
  599. var val string
  600. conn = conn.WithSession(session)
  601. err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
  602. return conn.QueryRow(v, "any")
  603. })
  604. assert.Equal(t, "2", val)
  605. return err
  606. }))
  607. val, err := r.Get("foo")
  608. assert.NoError(t, err)
  609. assert.Equal(t, `"2"`, val)
  610. })
  611. dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
  612. mock.ExpectBegin()
  613. mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
  614. mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
  615. mock.ExpectCommit()
  616. r := redistest.CreateRedis(t)
  617. conn := CachedConn{
  618. db: sqlx.NewSqlConnFromDB(db),
  619. cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
  620. }
  621. assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
  622. var val string
  623. conn = conn.WithSession(session)
  624. assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
  625. return conn.QueryRow(v, "any")
  626. }))
  627. assert.Equal(t, "2", val)
  628. _, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
  629. return conn.Exec("any")
  630. }, "foo")
  631. return err
  632. }))
  633. val, err := r.Get("foo")
  634. assert.NoError(t, err)
  635. assert.Empty(t, val)
  636. })
  637. }
  638. func resetStats() {
  639. atomic.StoreUint64(&stats.Total, 0)
  640. atomic.StoreUint64(&stats.Hit, 0)
  641. atomic.StoreUint64(&stats.Miss, 0)
  642. atomic.StoreUint64(&stats.DbFails, 0)
  643. }
  644. type dummySqlConn struct {
  645. queryRow func(any, string, ...any) error
  646. }
  647. func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
  648. return nil, nil
  649. }
  650. func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
  651. return nil, nil
  652. }
  653. func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
  654. return nil
  655. }
  656. func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
  657. return nil
  658. }
  659. func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
  660. return nil
  661. }
  662. func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
  663. return nil
  664. }
  665. func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
  666. return nil, nil
  667. }
  668. func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
  669. return nil, nil
  670. }
  671. func (d dummySqlConn) QueryRow(v any, query string, args ...any) error {
  672. return d.QueryRowCtx(context.Background(), v, query, args...)
  673. }
  674. func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args ...any) error {
  675. if d.queryRow != nil {
  676. return d.queryRow(v, query, args...)
  677. }
  678. return nil
  679. }
  680. func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
  681. return nil
  682. }
  683. func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
  684. return nil
  685. }
  686. func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
  687. return nil
  688. }
  689. func (d dummySqlConn) RawDB() (*sql.DB, error) {
  690. return nil, nil
  691. }
  692. func (d dummySqlConn) Transact(func(session sqlx.Session) error) error {
  693. return nil
  694. }
  695. type trackedConn struct {
  696. dummySqlConn
  697. execValue bool
  698. queryRowsValue bool
  699. transactValue bool
  700. }
  701. func (c *trackedConn) Exec(query string, args ...any) (sql.Result, error) {
  702. return c.ExecCtx(context.Background(), query, args...)
  703. }
  704. func (c *trackedConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
  705. c.execValue = true
  706. return c.dummySqlConn.ExecCtx(ctx, query, args...)
  707. }
  708. func (c *trackedConn) QueryRows(v any, query string, args ...any) error {
  709. return c.QueryRowsCtx(context.Background(), v, query, args...)
  710. }
  711. func (c *trackedConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
  712. c.queryRowsValue = true
  713. return c.dummySqlConn.QueryRowsCtx(ctx, v, query, args...)
  714. }
  715. func (c *trackedConn) RawDB() (*sql.DB, error) {
  716. return nil, nil
  717. }
  718. func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
  719. return c.TransactCtx(context.Background(), func(_ context.Context, session sqlx.Session) error {
  720. return fn(session)
  721. })
  722. }
  723. func (c *trackedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
  724. c.transactValue = true
  725. return c.dummySqlConn.TransactCtx(ctx, fn)
  726. }