Răsfoiți Sursa

fix QueryRowsPartial getTaggedFieldValueMap func (#2884)

Co-authored-by: yongkun.xiong <weilone@vip.qq.com>
YK.xiong 2 ani în urmă
părinte
comite
e735915d89
2 a modificat fișierele cu 112 adăugiri și 1 ștergeri
  1. 12 1
      core/stores/sqlx/orm.go
  2. 100 0
      core/stores/sqlx/orm_test.go

+ 12 - 1
core/stores/sqlx/orm.go

@@ -34,9 +34,20 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
 	result := make(map[string]any, size)
 
 	for i := 0; i < size; i++ {
+		if (rt.Field(i).Type.Kind() == reflect.Struct || rt.Field(i).Type.Kind() == reflect.Ptr) && rt.Field(i).Anonymous {
+			r, e := getTaggedFieldValueMap(reflect.Indirect(v).Field(i))
+			if e != nil {
+				return nil, e
+			}
+			for i2, i3 := range r {
+				result[i2] = i3
+			}
+			continue
+		}
+
 		key := parseTagName(rt.Field(i))
 		if len(key) == 0 {
-			return nil, nil
+			continue
 		}
 
 		valueField := reflect.Indirect(v).Field(i)

+ 100 - 0
core/stores/sqlx/orm_test.go

@@ -1041,6 +1041,106 @@ func TestUnmarshalRowError(t *testing.T) {
 	}
 }
 
+func TestAnonymousStructPr(t *testing.T) {
+	type Score struct {
+		Discipline string `db:"discipline"`
+		Score      uint   `db:"score"`
+	}
+	type ClassType struct {
+		Grade     sql.NullString `db:"grade"`
+		ClassName *string        `db:"class_name"`
+	}
+	type Class struct {
+		*ClassType
+		Score
+	}
+	expect := []*struct {
+		Name       string
+		Age        int64
+		Grade      sql.NullString
+		Discipline string
+		Score      uint
+		ClassName  string
+	}{
+		{
+			Name: "first",
+			Age:  2,
+			Grade: sql.NullString{
+				String: "",
+				Valid:  false,
+			},
+			ClassName:  "实验班",
+			Discipline: "数学",
+			Score:      100,
+		},
+		{
+			Name: "second",
+			Age:  3,
+			Grade: sql.NullString{
+				String: "大一",
+				Valid:  true,
+			},
+			ClassName:  "三班二年",
+			Discipline: "语文",
+			Score:      99,
+		},
+	}
+	var value []*struct {
+		Age int64 `db:"age"`
+		Class
+		Name string `db:"name"`
+	}
+
+	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100).
+			AddRow("second", 3, "大一", "语文", "三班二年", 99)
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, rows, true)
+		}, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone"))
+
+		for i, each := range expect {
+			assert.Equal(t, each.Name, value[i].Name)
+			assert.Equal(t, each.Age, value[i].Age)
+			assert.Equal(t, each.ClassName, *value[i].Class.ClassName)
+			assert.Equal(t, each.Discipline, value[i].Score.Discipline)
+			assert.Equal(t, each.Score, value[i].Score.Score)
+			assert.Equal(t, each.Grade, value[i].Class.Grade)
+		}
+	})
+}
+func TestAnonymousStructPrError(t *testing.T) {
+	type Score struct {
+		Discipline string `db:"discipline"`
+		score      uint   `db:"score"`
+	}
+	type ClassType struct {
+		Grade     sql.NullString `db:"grade"`
+		ClassName *string        `db:"class_name"`
+	}
+	type Class struct {
+		*ClassType
+		Score
+	}
+	var value []*struct {
+		Age int64 `db:"age"`
+		Class
+		Name string `db:"name"`
+	}
+
+	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"name", "age", "grade", "discipline", "class_name", "score"}).AddRow("first", 2, nil, "数学", "实验班", 100).
+			AddRow("second", 3, "大一", "语文", "三班二年", 99)
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, rows, true)
+		}, "select name, age,grade,discipline,class_name,score from users where user=?", "anyone"))
+		if len(value) > 0 {
+			assert.Equal(t, value[0].score, 0)
+		}
+	})
+}
+
 func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
 	logx.Disable()