Browse Source

add more tests

kevin 4 năm trước cách đây
mục cha
commit
6749c5b94a

+ 2 - 4
core/collection/set.go

@@ -15,6 +15,7 @@ const (
 	stringType
 	stringType
 )
 )
 
 
+// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
 type Set struct {
 type Set struct {
 	data map[interface{}]lang.PlaceholderType
 	data map[interface{}]lang.PlaceholderType
 	tp   int
 	tp   int
@@ -182,10 +183,7 @@ func (s *Set) add(i interface{}) {
 }
 }
 
 
 func (s *Set) setType(i interface{}) {
 func (s *Set) setType(i interface{}) {
-	if s.tp != untyped {
-		return
-	}
-
+	// s.tp can only be untyped here
 	switch i.(type) {
 	switch i.(type) {
 	case int:
 	case int:
 		s.tp = intType
 		s.tp = intType

+ 53 - 0
core/collection/set_test.go

@@ -5,8 +5,13 @@ import (
 	"testing"
 	"testing"
 
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/tal-tech/go-zero/core/logx"
 )
 )
 
 
+func init() {
+	logx.Disable()
+}
+
 func BenchmarkRawSet(b *testing.B) {
 func BenchmarkRawSet(b *testing.B) {
 	m := make(map[interface{}]struct{})
 	m := make(map[interface{}]struct{})
 	for i := 0; i < b.N; i++ {
 	for i := 0; i < b.N; i++ {
@@ -147,3 +152,51 @@ func TestCount(t *testing.T) {
 	// then
 	// then
 	assert.Equal(t, set.Count(), 3)
 	assert.Equal(t, set.Count(), 3)
 }
 }
+
+func TestKeysIntMismatch(t *testing.T) {
+	set := NewSet()
+	set.add(int64(1))
+	set.add(2)
+	vals := set.KeysInt()
+	assert.EqualValues(t, []int{2}, vals)
+}
+
+func TestKeysInt64Mismatch(t *testing.T) {
+	set := NewSet()
+	set.add(1)
+	set.add(int64(2))
+	vals := set.KeysInt64()
+	assert.EqualValues(t, []int64{2}, vals)
+}
+
+func TestKeysUintMismatch(t *testing.T) {
+	set := NewSet()
+	set.add(1)
+	set.add(uint(2))
+	vals := set.KeysUint()
+	assert.EqualValues(t, []uint{2}, vals)
+}
+
+func TestKeysUint64Mismatch(t *testing.T) {
+	set := NewSet()
+	set.add(1)
+	set.add(uint64(2))
+	vals := set.KeysUint64()
+	assert.EqualValues(t, []uint64{2}, vals)
+}
+
+func TestKeysStrMismatch(t *testing.T) {
+	set := NewSet()
+	set.add(1)
+	set.add("2")
+	vals := set.KeysStr()
+	assert.EqualValues(t, []string{"2"}, vals)
+}
+
+func TestSetType(t *testing.T) {
+	set := NewUnmanagedSet()
+	set.add(1)
+	set.add("2")
+	vals := set.Keys()
+	assert.ElementsMatch(t, []interface{}{1, "2"}, vals)
+}

+ 46 - 0
core/stores/sqlx/orm_test.go

@@ -22,6 +22,18 @@ func TestUnmarshalRowBool(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestUnmarshalRowBoolNotSettable(t *testing.T) {
+	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		var value bool
+		assert.NotNil(t, query(db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select value from users where user=?", "anyone"))
+	})
+}
+
 func TestUnmarshalRowInt(t *testing.T) {
 func TestUnmarshalRowInt(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
 		rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
@@ -228,6 +240,40 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
+	var value = new(struct {
+		Age  *int   `db:"age"`
+		Name string `db:"name"`
+	})
+
+	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.NotNil(t, query(db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select name, age from users where user=?", "anyone"))
+	})
+}
+
+func TestUnmarshalRowStructWithTagsPtr(t *testing.T) {
+	var value = new(struct {
+		Age  *int   `db:"age"`
+		Name string `db:"name"`
+	})
+
+	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
+		rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
+		mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
+
+		assert.Nil(t, query(db, func(rows *sql.Rows) error {
+			return unmarshalRow(value, rows, true)
+		}, "select name, age from users where user=?", "anyone"))
+		assert.Equal(t, "liao", value.Name)
+		assert.Equal(t, 5, *value.Age)
+	})
+}
+
 func TestUnmarshalRowsBool(t *testing.T) {
 func TestUnmarshalRowsBool(t *testing.T) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 	runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
 		var expect = []bool{true, false}
 		var expect = []bool{true, false}