Kevin Wan 1 жил өмнө
parent
commit
cd0f3726ed

+ 89 - 82
core/stores/sqlx/orm.go

@@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
 		}
 
 		valueField := reflect.Indirect(v).Field(i)
-		switch valueField.Kind() {
-		case reflect.Ptr:
-			if !valueField.CanInterface() {
-				return nil, ErrNotReadableValue
-			}
-			if valueField.IsNil() {
-				baseValueType := mapping.Deref(valueField.Type())
-				valueField.Set(reflect.New(baseValueType))
-			}
-			result[key] = valueField.Interface()
-		default:
-			if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
-				return nil, ErrNotReadableValue
-			}
-			result[key] = valueField.Addr().Interface()
+		valueData, err := getValueInterface(valueField)
+		if err != nil {
+			return nil, err
 		}
+
+		result[key] = valueData
 	}
 
 	return result, nil
 }
 
+func getValueInterface(value reflect.Value) (any, error) {
+	switch value.Kind() {
+	case reflect.Ptr:
+		if !value.CanInterface() {
+			return nil, ErrNotReadableValue
+		}
+
+		if value.IsNil() {
+			baseValueType := mapping.Deref(value.Type())
+			value.Set(reflect.New(baseValueType))
+		}
+
+		return value.Interface(), nil
+	default:
+		if !value.CanAddr() || !value.Addr().CanInterface() {
+			return nil, ErrNotReadableValue
+		}
+
+		return value.Addr().Interface(), nil
+	}
+}
+
 func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
 	fields := unwrapFields(v)
 	if strict && len(columns) < len(fields) {
@@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
 
 	values := make([]any, len(columns))
 	if len(taggedMap) == 0 {
+		if len(fields) < len(values) {
+			return nil, ErrNotMatchDestination
+		}
+
 		for i := 0; i < len(values); i++ {
 			valueField := fields[i]
-			switch valueField.Kind() {
-			case reflect.Ptr:
-				if !valueField.CanInterface() {
-					return nil, ErrNotReadableValue
-				}
-				if valueField.IsNil() {
-					baseValueType := mapping.Deref(valueField.Type())
-					valueField.Set(reflect.New(baseValueType))
-				}
-				values[i] = valueField.Interface()
-			default:
-				if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
-					return nil, ErrNotReadableValue
-				}
-				values[i] = valueField.Addr().Interface()
+			valueData, err := getValueInterface(valueField)
+			if err != nil {
+				return nil, err
 			}
+
+			values[i] = valueData
 		}
 	} else {
 		for i, column := range columns {
@@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
 		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
 		reflect.Float32, reflect.Float64,
 		reflect.String:
-		if rve.CanSet() {
-			return scanner.Scan(v)
+		if !rve.CanSet() {
+			return ErrNotSettable
 		}
 
-		return ErrNotSettable
+		return scanner.Scan(v)
 	case reflect.Struct:
 		columns, err := scanner.Columns()
 		if err != nil {
@@ -183,69 +189,66 @@ func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
 	rt := reflect.TypeOf(v)
 	rte := rt.Elem()
 	rve := rv.Elem()
+	if !rve.CanSet() {
+		return ErrNotSettable
+	}
+
 	switch rte.Kind() {
 	case reflect.Slice:
-		if rve.CanSet() {
-			ptr := rte.Elem().Kind() == reflect.Ptr
-			appendFn := func(item reflect.Value) {
-				if ptr {
-					rve.Set(reflect.Append(rve, item))
-				} else {
-					rve.Set(reflect.Append(rve, reflect.Indirect(item)))
-				}
+		ptr := rte.Elem().Kind() == reflect.Ptr
+		appendFn := func(item reflect.Value) {
+			if ptr {
+				rve.Set(reflect.Append(rve, item))
+			} else {
+				rve.Set(reflect.Append(rve, reflect.Indirect(item)))
 			}
-			fillFn := func(value any) error {
-				if rve.CanSet() {
-					if err := scanner.Scan(value); err != nil {
-						return err
-					}
-
-					appendFn(reflect.ValueOf(value))
-					return nil
-				}
-				return ErrNotSettable
+		}
+		fillFn := func(value any) error {
+			if err := scanner.Scan(value); err != nil {
+				return err
 			}
 
-			base := mapping.Deref(rte.Elem())
-			switch base.Kind() {
-			case reflect.Bool,
-				reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
-				reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
-				reflect.Float32, reflect.Float64,
-				reflect.String:
-				for scanner.Next() {
-					value := reflect.New(base)
-					if err := fillFn(value.Interface()); err != nil {
-						return err
-					}
+			appendFn(reflect.ValueOf(value))
+			return nil
+		}
+
+		base := mapping.Deref(rte.Elem())
+		switch base.Kind() {
+		case reflect.Bool,
+			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+			reflect.Float32, reflect.Float64,
+			reflect.String:
+			for scanner.Next() {
+				value := reflect.New(base)
+				if err := fillFn(value.Interface()); err != nil {
+					return err
 				}
-			case reflect.Struct:
-				columns, err := scanner.Columns()
+			}
+		case reflect.Struct:
+			columns, err := scanner.Columns()
+			if err != nil {
+				return err
+			}
+
+			for scanner.Next() {
+				value := reflect.New(base)
+				values, err := mapStructFieldsIntoSlice(value, columns, strict)
 				if err != nil {
 					return err
 				}
 
-				for scanner.Next() {
-					value := reflect.New(base)
-					values, err := mapStructFieldsIntoSlice(value, columns, strict)
-					if err != nil {
-						return err
-					}
-
-					if err := scanner.Scan(values...); err != nil {
-						return err
-					}
-
-					appendFn(value)
+				if err := scanner.Scan(values...); err != nil {
+					return err
 				}
-			default:
-				return ErrUnsupportedValueType
-			}
 
-			return nil
+				appendFn(value)
+			}
+		default:
+			return ErrUnsupportedValueType
 		}
 
-		return ErrNotSettable
+		return nil
 	default:
 		return ErrUnsupportedValueType
 	}
@@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
 
 	for i := 0; i < indirect.NumField(); i++ {
 		child := indirect.Field(i)
+		if !child.CanSet() {
+			continue
+		}
+
 		if child.Kind() == reflect.Ptr && child.IsNil() {
 			baseValueType := mapping.Deref(child.Type())
 			child.Set(reflect.New(baseValueType))

+ 215 - 28
core/stores/sqlx/orm_test.go

@@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) {
 		}, "select value from users where user=?", "anyone"))
 		assert.True(t, value)
 	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		var value struct {
+			Value bool `db:"value"`
+		}
+		assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select value from users where user=?", "anyone"))
+	})
 }
 
 func TestUnmarshalRowBoolNotSettable(t *testing.T) {
@@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
 }
 
 func TestUnmarshalRowStruct(t *testing.T) {
-	value := new(struct {
-		Name string
-		Age  int
-	})
-
 	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		value := new(struct {
+			Name string
+			Age  int
+		})
+
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
@@ -222,22 +234,110 @@ func TestUnmarshalRowStruct(t *testing.T) {
 		assert.Equal(t, "liao", value.Name)
 		assert.Equal(t, 5, value.Age)
 	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		value := new(struct {
+			Name string
+			Age  int
+		})
+
+		errAny := errors.New("any error")
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, &mockedScanner{
+				colErr: errAny,
+				next:   1,
+			}, true)
+		}, "select name, age from users where user=?", "anyone"), errAny)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		value := new(struct {
+			Name string
+			age  *int
+		})
+
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		type myString chan int
+		var value myString
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(&value, rows, true)
+		}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
+	})
 }
 
 func TestUnmarshalRowStructWithTags(t *testing.T) {
-	value := new(struct {
-		Age  int    `db:"age"`
-		Name string `db:"name"`
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		value := new(struct {
+			Age  int    `db:"age"`
+			Name string `db:"name"`
+		})
+
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select name, age from users where user=?", "anyone"))
+		assert.Equal(t, "liao", value.Name)
+		assert.Equal(t, 5, value.Age)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		value := new(struct {
+			age  *int   `db:"age"`
+			Name string `db:"name"`
+		})
+
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		var value struct {
+			Age  *int    `db:"age"`
+			Name *string `db:"name"`
+		}
+
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRow(&value, rows, true)
+		}, "select name, age from users where user=?", "anyone"))
+		assert.Equal(t, "liao", *value.Name)
+		assert.Equal(t, 5, *value.Age)
 	})
 
 	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		value := new(struct {
+			Age  int `db:"age"`
+			Name string
+		})
+
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
 			return unmarshalRow(value, rows, true)
 		}, "select name, age from users where user=?", "anyone"))
-		assert.Equal(t, "liao", value.Name)
 		assert.Equal(t, 5, value.Age)
 	})
 }
@@ -270,6 +370,42 @@ func TestUnmarshalRowsBool(t *testing.T) {
 		}, "select value from users where user=?", "anyone"))
 		assert.EqualValues(t, expect, value)
 	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		var value []bool
+		assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(value, rows, true)
+		}, "select value from users where user=?", "anyone"))
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		var value struct {
+			value []bool `db:"value"`
+		}
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, rows, true)
+		}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		var value []bool
+		errAny := errors.New("any")
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, &mockedScanner{
+				scanErr: errAny,
+				next:    1,
+			}, true)
+		}, "select value from users where user=?", "anyone"), errAny)
+	})
 }
 
 func TestUnmarshalRowsInt(t *testing.T) {
@@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
 }
 
 func TestUnmarshalRowsStruct(t *testing.T) {
-	expect := []struct {
-		Name string
-		Age  int64
-	}{
-		{
-			Name: "first",
-			Age:  2,
-		},
-		{
-			Name: "second",
-			Age:  3,
-		},
-	}
-	var value []struct {
-		Name string
-		Age  int64
-	}
-
 	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		expect := []struct {
+			Name string
+			Age  int64
+		}{
+			{
+				Name: "first",
+				Age:  2,
+			},
+			{
+				Name: "second",
+				Age:  3,
+			},
+		}
+		var value []struct {
+			Name string
+			Age  int64
+		}
+
 		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
 		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
 		assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
 			assert.Equal(t, each.Age, value[i].Age)
 		}
 	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		var value []struct {
+			Name string
+			Age  int64
+		}
+
+		errAny := errors.New("any error")
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, &mockedScanner{
+				colErr: errAny,
+				next:   1,
+			}, true)
+		}, "select name, age from users where user=?", "anyone"), errAny)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		var value []struct {
+			Name string
+			Age  int64
+		}
+
+		errAny := errors.New("any error")
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, &mockedScanner{
+				cols:    []string{"name", "age"},
+				scanErr: errAny,
+				next:    1,
+			}, true)
+		}, "select name, age from users where user=?", "anyone"), errAny)
+	})
+
+	dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		var value []chan int
+
+		errAny := errors.New("any error")
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+		assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
+			return unmarshalRows(&value, &mockedScanner{
+				cols:    []string{"name", "age"},
+				scanErr: errAny,
+				next:    1,
+			}, true)
+		}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
+	})
 }
 
 func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
@@ -1163,6 +1349,7 @@ func TestAnonymousStructPrError(t *testing.T) {
 }
 
 type mockedScanner struct {
+	cols    []string
 	colErr  error
 	scanErr error
 	err     error
@@ -1170,7 +1357,7 @@ type mockedScanner struct {
 }
 
 func (m *mockedScanner) Columns() ([]string, error) {
-	return nil, m.colErr
+	return m.cols, m.colErr
 }
 
 func (m *mockedScanner) Err() error {