utils.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. package mapping
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math"
  7. "reflect"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "github.com/zeromicro/go-zero/core/stringx"
  12. )
  const (
  	// Option names recognised after the key in a struct tag.
  	defaultOption  = "default"
  	stringOption   = "string"
  	optionalOption = "optional"
  	optionsOption  = "options"
  	rangeOption    = "range"

  	// Separators used inside option values.
  	optionSeparator = "|"
  	equalToken      = "="

  	// Characters with special meaning while splitting a tag into segments.
  	escapeChar         = '\\'
  	leftBracket        = '('
  	rightBracket       = ')'
  	leftSquareBracket  = '['
  	rightSquareBracket = ']'
  	segmentSeparator   = ','
  )
  28. var (
  29. errUnsupportedType = errors.New("unsupported type on setting field value")
  30. errNumberRange = errors.New("wrong number range setting")
  31. optionsCache = make(map[string]optionsCacheValue)
  32. cacheLock sync.RWMutex
  33. structRequiredCache = make(map[reflect.Type]requiredCacheValue)
  34. structCacheLock sync.RWMutex
  35. )
  36. type (
  37. optionsCacheValue struct {
  38. key string
  39. options *fieldOptions
  40. err error
  41. }
  42. requiredCacheValue struct {
  43. required bool
  44. err error
  45. }
  46. )
  // Deref returns the element type of t when t is a pointer type and t itself
  // otherwise. Only a single level of pointer indirection is removed.
  func Deref(t reflect.Type) reflect.Type {
  	if t.Kind() != reflect.Ptr {
  		return t
  	}
  	return t.Elem()
  }
  54. // Repr returns the string representation of v.
  55. func Repr(v interface{}) string {
  56. if v == nil {
  57. return ""
  58. }
  59. // if func (v *Type) String() string, we can't use Elem()
  60. switch vt := v.(type) {
  61. case fmt.Stringer:
  62. return vt.String()
  63. }
  64. val := reflect.ValueOf(v)
  65. if val.Kind() == reflect.Ptr && !val.IsNil() {
  66. val = val.Elem()
  67. }
  68. return reprOfValue(val)
  69. }
  // ValidatePtr returns an error unless v refers to a valid, non-nil pointer
  // value. A nil v is reported as an error instead of causing a panic.
  func ValidatePtr(v *reflect.Value) error {
  	// The order of these checks matters. v itself must be non-nil before any
  	// method is called on it. IsNil may only be called once Kind() is known to
  	// be reflect.Ptr, because IsNil panics for most other kinds.
  	if v == nil || !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() {
  		return fmt.Errorf("not a valid pointer: %v", v)
  	}
  	return nil
  }
  79. func convertType(kind reflect.Kind, str string) (interface{}, error) {
  80. switch kind {
  81. case reflect.Bool:
  82. return str == "1" || strings.ToLower(str) == "true", nil
  83. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  84. intValue, err := strconv.ParseInt(str, 10, 64)
  85. if err != nil {
  86. return 0, fmt.Errorf("the value %q cannot parsed as int", str)
  87. }
  88. return intValue, nil
  89. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  90. uintValue, err := strconv.ParseUint(str, 10, 64)
  91. if err != nil {
  92. return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
  93. }
  94. return uintValue, nil
  95. case reflect.Float32, reflect.Float64:
  96. floatValue, err := strconv.ParseFloat(str, 64)
  97. if err != nil {
  98. return 0, fmt.Errorf("the value %q cannot parsed as float", str)
  99. }
  100. return floatValue, nil
  101. case reflect.String:
  102. return str, nil
  103. default:
  104. return nil, errUnsupportedType
  105. }
  106. }
  107. func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
  108. segments := parseSegments(value)
  109. key := strings.TrimSpace(segments[0])
  110. options := segments[1:]
  111. if len(options) == 0 {
  112. return key, nil, nil
  113. }
  114. var fieldOpts fieldOptions
  115. for _, segment := range options {
  116. option := strings.TrimSpace(segment)
  117. if err := parseOption(&fieldOpts, field.Name, option); err != nil {
  118. return "", nil, err
  119. }
  120. }
  121. return key, &fieldOpts, nil
  122. }
  // ensureValue follows the chain of pointers starting at v. It replaces every
  // nil pointer on the way with a pointer to a fresh zero value and returns
  // the first non-pointer value reached. Nil pointers in the chain are
  // assigned with Set, so they must be settable.
  func ensureValue(v reflect.Value) reflect.Value {
  	for v.Kind() == reflect.Ptr {
  		if v.IsNil() {
  			v.Set(reflect.New(v.Type().Elem()))
  		}
  		v = v.Elem()
  	}
  	return v
  }
  137. func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
  138. numFields := tp.NumField()
  139. for i := 0; i < numFields; i++ {
  140. childField := tp.Field(i)
  141. if usingDifferentKeys(tag, childField) {
  142. return true, nil
  143. }
  144. _, opts, err := parseKeyAndOptions(tag, childField)
  145. if err != nil {
  146. return false, err
  147. }
  148. if opts == nil {
  149. if childField.Type.Kind() != reflect.Struct {
  150. return true, nil
  151. }
  152. if required, err := implicitValueRequiredStruct(tag, childField.Type); err != nil {
  153. return false, err
  154. } else if required {
  155. return true, nil
  156. }
  157. } else if !opts.Optional && len(opts.Default) == 0 {
  158. return true, nil
  159. } else if len(opts.OptionalDep) > 0 && opts.OptionalDep[0] == notSymbol {
  160. return true, nil
  161. }
  162. }
  163. return false, nil
  164. }
  165. func isLeftInclude(b byte) (bool, error) {
  166. switch b {
  167. case '[':
  168. return true, nil
  169. case '(':
  170. return false, nil
  171. default:
  172. return false, errNumberRange
  173. }
  174. }
  175. func isRightInclude(b byte) (bool, error) {
  176. switch b {
  177. case ']':
  178. return true, nil
  179. case ')':
  180. return false, nil
  181. default:
  182. return false, errNumberRange
  183. }
  184. }
  // maybeNewValue allocates a zero value for value when the struct field is
  // declared with a pointer type and value currently holds nil. In that case
  // value must be settable.
  func maybeNewValue(field reflect.StructField, value reflect.Value) {
  	if field.Type.Kind() != reflect.Ptr || !value.IsNil() {
  		return
  	}
  	value.Set(reflect.New(value.Type().Elem()))
  }
  190. func parseGroupedSegments(val string) []string {
  191. val = strings.TrimLeftFunc(val, func(r rune) bool {
  192. return r == leftBracket || r == leftSquareBracket
  193. })
  194. val = strings.TrimRightFunc(val, func(r rune) bool {
  195. return r == rightBracket || r == rightSquareBracket
  196. })
  197. return parseSegments(val)
  198. }
  199. // don't modify returned fieldOptions, it's cached and shared among different calls.
  200. func parseKeyAndOptions(tagName string, field reflect.StructField) (string, *fieldOptions, error) {
  201. value := field.Tag.Get(tagName)
  202. if len(value) == 0 {
  203. return field.Name, nil, nil
  204. }
  205. cacheLock.RLock()
  206. cache, ok := optionsCache[value]
  207. cacheLock.RUnlock()
  208. if ok {
  209. return stringx.TakeOne(cache.key, field.Name), cache.options, cache.err
  210. }
  211. key, options, err := doParseKeyAndOptions(field, value)
  212. cacheLock.Lock()
  213. optionsCache[value] = optionsCacheValue{
  214. key: key,
  215. options: options,
  216. err: err,
  217. }
  218. cacheLock.Unlock()
  219. return stringx.TakeOne(key, field.Name), options, err
  220. }
  221. // support below notations:
  222. // [:5] (:5] [:5) (:5)
  223. // [1:] [1:) (1:] (1:)
  224. // [1:5] [1:5) (1:5] (1:5)
  225. func parseNumberRange(str string) (*numberRange, error) {
  226. if len(str) == 0 {
  227. return nil, errNumberRange
  228. }
  229. leftInclude, err := isLeftInclude(str[0])
  230. if err != nil {
  231. return nil, err
  232. }
  233. str = str[1:]
  234. if len(str) == 0 {
  235. return nil, errNumberRange
  236. }
  237. rightInclude, err := isRightInclude(str[len(str)-1])
  238. if err != nil {
  239. return nil, err
  240. }
  241. str = str[:len(str)-1]
  242. fields := strings.Split(str, ":")
  243. if len(fields) != 2 {
  244. return nil, errNumberRange
  245. }
  246. if len(fields[0]) == 0 && len(fields[1]) == 0 {
  247. return nil, errNumberRange
  248. }
  249. var left float64
  250. if len(fields[0]) > 0 {
  251. var err error
  252. if left, err = strconv.ParseFloat(fields[0], 64); err != nil {
  253. return nil, err
  254. }
  255. } else {
  256. left = -math.MaxFloat64
  257. }
  258. var right float64
  259. if len(fields[1]) > 0 {
  260. var err error
  261. if right, err = strconv.ParseFloat(fields[1], 64); err != nil {
  262. return nil, err
  263. }
  264. } else {
  265. right = math.MaxFloat64
  266. }
  267. if left > right {
  268. return nil, errNumberRange
  269. }
  270. // [2:2] valid
  271. // [2:2) invalid
  272. // (2:2] invalid
  273. // (2:2) invalid
  274. if left == right {
  275. if !leftInclude || !rightInclude {
  276. return nil, errNumberRange
  277. }
  278. }
  279. return &numberRange{
  280. left: left,
  281. leftInclude: leftInclude,
  282. right: right,
  283. rightInclude: rightInclude,
  284. }, nil
  285. }
  286. func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
  287. switch {
  288. case option == stringOption:
  289. fieldOpts.FromString = true
  290. case strings.HasPrefix(option, optionalOption):
  291. segs := strings.Split(option, equalToken)
  292. switch len(segs) {
  293. case 1:
  294. fieldOpts.Optional = true
  295. case 2:
  296. fieldOpts.Optional = true
  297. fieldOpts.OptionalDep = segs[1]
  298. default:
  299. return fmt.Errorf("field %s has wrong optional", fieldName)
  300. }
  301. case option == optionalOption:
  302. fieldOpts.Optional = true
  303. case strings.HasPrefix(option, optionsOption):
  304. segs := strings.Split(option, equalToken)
  305. if len(segs) != 2 {
  306. return fmt.Errorf("field %s has wrong options", fieldName)
  307. }
  308. fieldOpts.Options = parseOptions(segs[1])
  309. case strings.HasPrefix(option, defaultOption):
  310. segs := strings.Split(option, equalToken)
  311. if len(segs) != 2 {
  312. return fmt.Errorf("field %s has wrong default option", fieldName)
  313. }
  314. fieldOpts.Default = strings.TrimSpace(segs[1])
  315. case strings.HasPrefix(option, rangeOption):
  316. segs := strings.Split(option, equalToken)
  317. if len(segs) != 2 {
  318. return fmt.Errorf("field %s has wrong range", fieldName)
  319. }
  320. nr, err := parseNumberRange(segs[1])
  321. if err != nil {
  322. return err
  323. }
  324. fieldOpts.Range = nr
  325. }
  326. return nil
  327. }
  328. // parseOptions parses the given options in tag.
  329. // for example: `json:"name,options=foo|bar"` or `json:"name,options=[foo,bar]"`
  330. func parseOptions(val string) []string {
  331. if len(val) == 0 {
  332. return nil
  333. }
  334. if val[0] == leftSquareBracket {
  335. return parseGroupedSegments(val)
  336. }
  337. return strings.Split(val, optionSeparator)
  338. }
  339. func parseSegments(val string) []string {
  340. var segments []string
  341. var escaped, grouped bool
  342. var buf strings.Builder
  343. for _, ch := range val {
  344. if escaped {
  345. buf.WriteRune(ch)
  346. escaped = false
  347. continue
  348. }
  349. switch ch {
  350. case segmentSeparator:
  351. if grouped {
  352. buf.WriteRune(ch)
  353. } else {
  354. // need to trim spaces, but we cannot ignore empty string,
  355. // because the first segment stands for the key might be empty.
  356. // if ignored, the later tag will be used as the key.
  357. segments = append(segments, strings.TrimSpace(buf.String()))
  358. buf.Reset()
  359. }
  360. case escapeChar:
  361. if grouped {
  362. buf.WriteRune(ch)
  363. } else {
  364. escaped = true
  365. }
  366. case leftBracket, leftSquareBracket:
  367. buf.WriteRune(ch)
  368. grouped = true
  369. case rightBracket, rightSquareBracket:
  370. buf.WriteRune(ch)
  371. grouped = false
  372. default:
  373. buf.WriteRune(ch)
  374. }
  375. }
  376. last := strings.TrimSpace(buf.String())
  377. // ignore last empty string
  378. if len(last) > 0 {
  379. segments = append(segments, last)
  380. }
  381. return segments
  382. }
  // reprOfValue formats the value held by val as follows:
  //   - bools and the built-in integer types via strconv;
  //   - floats in the shortest 'f' form at their own precision;
  //   - errors via Error(), checked before fmt.Stringer, and Stringers via
  //     String();
  //   - strings and []byte as their text;
  //   - anything else with fmt.Sprint.
  //
  // NOTE(review): val.Interface() panics for an invalid Value or one read from
  // an unexported field; callers presumably never pass those — confirm.
  func reprOfValue(val reflect.Value) string {
  	iface := val.Interface()
  	switch x := iface.(type) {
  	case bool:
  		return strconv.FormatBool(x)
  	case error:
  		return x.Error()
  	case float32:
  		return strconv.FormatFloat(float64(x), 'f', -1, 32)
  	case float64:
  		return strconv.FormatFloat(x, 'f', -1, 64)
  	case fmt.Stringer:
  		return x.String()
  	case int:
  		return strconv.FormatInt(int64(x), 10)
  	case int8:
  		return strconv.FormatInt(int64(x), 10)
  	case int16:
  		return strconv.FormatInt(int64(x), 10)
  	case int32:
  		return strconv.FormatInt(int64(x), 10)
  	case int64:
  		return strconv.FormatInt(x, 10)
  	case string:
  		return x
  	case uint:
  		return strconv.FormatUint(uint64(x), 10)
  	case uint8:
  		return strconv.FormatUint(uint64(x), 10)
  	case uint16:
  		return strconv.FormatUint(uint64(x), 10)
  	case uint32:
  		return strconv.FormatUint(uint64(x), 10)
  	case uint64:
  		return strconv.FormatUint(x, 10)
  	case []byte:
  		return string(x)
  	default:
  		return fmt.Sprint(iface)
  	}
  }
  423. func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
  424. switch kind {
  425. case reflect.Bool:
  426. value.SetBool(v.(bool))
  427. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  428. value.SetInt(v.(int64))
  429. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  430. value.SetUint(v.(uint64))
  431. case reflect.Float32, reflect.Float64:
  432. value.SetFloat(v.(float64))
  433. case reflect.String:
  434. value.SetString(v.(string))
  435. default:
  436. return errUnsupportedType
  437. }
  438. return nil
  439. }
  440. func setValue(kind reflect.Kind, value reflect.Value, str string) error {
  441. if !value.CanSet() {
  442. return errValueNotSettable
  443. }
  444. value = ensureValue(value)
  445. v, err := convertType(kind, str)
  446. if err != nil {
  447. return err
  448. }
  449. return setMatchedPrimitiveValue(kind, value, v)
  450. }
  451. func structValueRequired(tag string, tp reflect.Type) (bool, error) {
  452. structCacheLock.RLock()
  453. val, ok := structRequiredCache[tp]
  454. structCacheLock.RUnlock()
  455. if ok {
  456. return val.required, val.err
  457. }
  458. required, err := implicitValueRequiredStruct(tag, tp)
  459. structCacheLock.Lock()
  460. structRequiredCache[tp] = requiredCacheValue{
  461. required: required,
  462. err: err,
  463. }
  464. structCacheLock.Unlock()
  465. return required, err
  466. }
  // toFloat64 converts v to float64 when its dynamic type is exactly one of
  // the built-in integer or floating-point types, reporting true. Anything
  // else, including named numeric types and nil, yields (0, false).
  func toFloat64(v interface{}) (float64, bool) {
  	var f float64
  	switch n := v.(type) {
  	case int:
  		f = float64(n)
  	case int8:
  		f = float64(n)
  	case int16:
  		f = float64(n)
  	case int32:
  		f = float64(n)
  	case int64:
  		f = float64(n)
  	case uint:
  		f = float64(n)
  	case uint8:
  		f = float64(n)
  	case uint16:
  		f = float64(n)
  	case uint32:
  		f = float64(n)
  	case uint64:
  		f = float64(n)
  	case float32:
  		f = float64(n)
  	case float64:
  		f = n
  	default:
  		return 0, false
  	}
  	return f, true
  }
  // usingDifferentKeys reports whether field carries struct tags but none
  // under key. A field with no tag at all reports false.
  func usingDifferentKeys(key string, field reflect.StructField) bool {
  	if len(field.Tag) == 0 {
  		return false
  	}
  	_, found := field.Tag.Lookup(key)
  	return !found
  }
  505. func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error {
  506. if !value.CanSet() {
  507. return errValueNotSettable
  508. }
  509. v, err := convertType(kind, str)
  510. if err != nil {
  511. return err
  512. }
  513. if err := validateValueRange(v, opts); err != nil {
  514. return err
  515. }
  516. return setMatchedPrimitiveValue(kind, value, v)
  517. }
  518. func validateJsonNumberRange(v json.Number, opts *fieldOptionsWithContext) error {
  519. if opts == nil || opts.Range == nil {
  520. return nil
  521. }
  522. fv, err := v.Float64()
  523. if err != nil {
  524. return err
  525. }
  526. return validateNumberRange(fv, opts.Range)
  527. }
  528. func validateNumberRange(fv float64, nr *numberRange) error {
  529. if nr == nil {
  530. return nil
  531. }
  532. if (nr.leftInclude && fv < nr.left) || (!nr.leftInclude && fv <= nr.left) {
  533. return errNumberRange
  534. }
  535. if (nr.rightInclude && fv > nr.right) || (!nr.rightInclude && fv >= nr.right) {
  536. return errNumberRange
  537. }
  538. return nil
  539. }
  540. func validateValueInOptions(val interface{}, options []string) error {
  541. if len(options) > 0 {
  542. switch v := val.(type) {
  543. case string:
  544. if !stringx.Contains(options, v) {
  545. return fmt.Errorf(`error: value "%s" is not defined in options "%v"`, v, options)
  546. }
  547. default:
  548. if !stringx.Contains(options, Repr(v)) {
  549. return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
  550. }
  551. }
  552. }
  553. return nil
  554. }
  555. func validateValueRange(mapValue interface{}, opts *fieldOptionsWithContext) error {
  556. if opts == nil || opts.Range == nil {
  557. return nil
  558. }
  559. fv, ok := toFloat64(mapValue)
  560. if !ok {
  561. return errNumberRange
  562. }
  563. return validateNumberRange(fv, opts.Range)
  564. }