Browse Source

chore: refactor orm code (#3015)

Kevin Wan 2 years ago
parent
commit
cca45be3c5
2 changed files with 45 additions and 21 deletions
  1. 11 8
      core/stores/sqlx/orm.go
  2. 34 13
      core/stores/sqlx/orm_test.go

+ 11 - 8
core/stores/sqlx/orm.go

@@ -34,18 +34,21 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
 	result := make(map[string]any, size)
 
 	for i := 0; i < size; i++ {
-		if (rt.Field(i).Type.Kind() == reflect.Struct || rt.Field(i).Type.Kind() == reflect.Ptr) && rt.Field(i).Anonymous {
-			r, e := getTaggedFieldValueMap(reflect.Indirect(v).Field(i))
-			if e != nil {
-				return nil, e
+		field := rt.Field(i)
+		if field.Anonymous && mapping.Deref(field.Type).Kind() == reflect.Struct {
+			inner, err := getTaggedFieldValueMap(reflect.Indirect(v).Field(i))
+			if err != nil {
+				return nil, err
 			}
-			for i2, i3 := range r {
-				result[i2] = i3
+
+			for key, val := range inner {
+				result[key] = val
 			}
+
 			continue
 		}
 
-		key := parseTagName(rt.Field(i))
+		key := parseTagName(field)
 		if len(key) == 0 {
 			continue
 		}
@@ -125,7 +128,7 @@ func parseTagName(field reflect.StructField) string {
 	}
 
 	options := strings.Split(key, ",")
-	return options[0]
+	return strings.TrimSpace(options[0])
 }
 
 func unmarshalRow(v any, scanner rowsScanner, strict bool) error {

+ 34 - 13
core/stores/sqlx/orm_test.go

@@ -1069,19 +1069,19 @@ func TestAnonymousStructPr(t *testing.T) {
 				String: "",
 				Valid:  false,
 			},
-			ClassName:  "实验班",
-			Discipline: "数学",
+			ClassName:  "experimental class",
+			Discipline: "math",
 			Score:      100,
 		},
 		{
 			Name: "second",
 			Age:  3,
 			Grade: sql.NullString{
-				String: "大一",
+				String: "grade one",
 				Valid:  true,
 			},
-			ClassName:  "三班二年",
-			Discipline: "语文",
+			ClassName:  "class three grade two",
+			Discipline: "chinese",
 			Score:      99,
 		},
 	}
@@ -1092,12 +1092,22 @@ func TestAnonymousStructPr(t *testing.T) {
 	}
 
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
-		rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100).
-			AddRow("second", 3, "大一", "语文", "三班二年", 99)
-		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		rs := sqlmock.NewRows([]string{
+			"name",
+			"age",
+			"grade",
+			"discipline",
+			"class_name",
+			"score",
+		}).
+			AddRow("first", 2, nil, "math", "experimental class", 100).
+			AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
+		mock.ExpectQuery("select (.+) from users where user=?").
+			WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
-		}, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone"))
+		}, "select name, age,grade,discipline,class_name,score from users where user=?",
+			"anyone"))
 
 		for i, each := range expect {
 			assert.Equal(t, each.Name, value[i].Name)
@@ -1109,6 +1119,7 @@ func TestAnonymousStructPr(t *testing.T) {
 		}
 	})
 }
+
 func TestAnonymousStructPrError(t *testing.T) {
 	type Score struct {
 		Discipline string `db:"discipline"`
@@ -1129,12 +1140,22 @@ func TestAnonymousStructPrError(t *testing.T) {
 	}
 
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
-		rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100).
-			AddRow("second", 3, "大一", "语文", "三班二年", 99)
-		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		rs := sqlmock.NewRows([]string{
+			"name",
+			"age",
+			"grade",
+			"discipline",
+			"class_name",
+			"score",
+		}).
+			AddRow("first", 2, nil, "math", "experimental class", 100).
+			AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
+		mock.ExpectQuery("select (.+) from users where user=?").
+			WithArgs("anyone").WillReturnRows(rs)
 		assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRows(&value, rows, true)
-		}, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone"))
+		}, "select name, age,grade,discipline,class_name,score from users where user=?",
+			"anyone"))
 		if len(value) > 0 {
 			assert.Equal(t, value[0].score, 0)
 		}