Browse Source

chore: add more tests (#3279)

Kevin Wan 1 year ago
parent
commit
4a2a8d9e45

+ 5 - 8
core/conf/config_test.go

@@ -35,11 +35,11 @@ func TestConfigJson(t *testing.T) {
 	"c": "${FOO}",
 	"d": "abcd!@#$112"
 }`
+	t.Setenv("FOO", "2")
+
 	for _, test := range tests {
 		test := test
 		t.Run(test, func(t *testing.T) {
-			os.Setenv("FOO", "2")
-			defer os.Unsetenv("FOO")
 			tmpfile, err := createTempFile(test, text)
 			assert.Nil(t, err)
 			defer os.Remove(tmpfile)
@@ -81,8 +81,7 @@ b = 1
 c = "${FOO}"
 d = "abcd!@#$112"
 `
-	os.Setenv("FOO", "2")
-	defer os.Unsetenv("FOO")
+	t.Setenv("FOO", "2")
 	tmpfile, err := createTempFile(".toml", text)
 	assert.Nil(t, err)
 	defer os.Remove(tmpfile)
@@ -207,8 +206,7 @@ b = 1
 c = "${FOO}"
 d = "abcd!@#112"
 `
-	os.Setenv("FOO", "2")
-	defer os.Unsetenv("FOO")
+	t.Setenv("FOO", "2")
 	tmpfile, err := createTempFile(".toml", text)
 	assert.Nil(t, err)
 	defer os.Remove(tmpfile)
@@ -239,11 +237,10 @@ func TestConfigJsonEnv(t *testing.T) {
 	"c": "${FOO}",
 	"d": "abcd!@#$a12 3"
 }`
+	t.Setenv("FOO", "2")
 	for _, test := range tests {
 		test := test
 		t.Run(test, func(t *testing.T) {
-			os.Setenv("FOO", "2")
-			defer os.Unsetenv("FOO")
 			tmpfile, err := createTempFile(test, text)
 			assert.Nil(t, err)
 			defer os.Remove(tmpfile)

+ 1 - 2
core/conf/properties_test.go

@@ -45,8 +45,7 @@ func TestPropertiesEnv(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(tmpfile)
 
-	os.Setenv("FOO", "2")
-	defer os.Unsetenv("FOO")
+	t.Setenv("FOO", "2")
 
 	props, err := LoadProperties(tmpfile, UseEnv())
 	assert.Nil(t, err)

+ 11 - 16
core/mapping/unmarshaler.go

@@ -513,8 +513,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
 	vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
 	derefedFieldType := Deref(fieldType)
 	typeKind := derefedFieldType.Kind()
-	valueKind := reflect.TypeOf(vp.value).Kind()
 	mapValue := vp.value
+	valueKind := reflect.TypeOf(mapValue).Kind()
 
 	switch {
 	case valueKind == reflect.Map && typeKind == reflect.Struct:
@@ -527,6 +527,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
 			current: mapValuer(mv),
 			parent:  vp.parent,
 		}, fullName)
+	case typeKind == reflect.Slice && valueKind == reflect.Slice:
+		return u.fillSlice(fieldType, value, mapValue)
 	case valueKind == reflect.Map && typeKind == reflect.Map:
 		return u.fillMap(fieldType, value, mapValue)
 	case valueKind == reflect.String && typeKind == reflect.Map:
@@ -545,23 +547,16 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
 	typeKind := Deref(fieldType).Kind()
 	valueKind := reflect.TypeOf(mapValue).Kind()
 
-	switch {
-	case typeKind == reflect.Slice && valueKind == reflect.Slice:
-		return u.fillSlice(fieldType, value, mapValue)
-	case typeKind == reflect.Map && valueKind == reflect.Map:
-		return u.fillMap(fieldType, value, mapValue)
+	switch v := mapValue.(type) {
+	case json.Number:
+		return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
 	default:
-		switch v := mapValue.(type) {
-		case json.Number:
-			return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
-		default:
-			if typeKind == valueKind {
-				if err := validateValueInOptions(mapValue, opts.options()); err != nil {
-					return err
-				}
-
-				return fillWithSameType(fieldType, value, mapValue, opts)
+		if typeKind == valueKind {
+			if err := validateValueInOptions(mapValue, opts.options()); err != nil {
+				return err
 			}
+
+			return fillWithSameType(fieldType, value, mapValue, opts)
 		}
 	}
 

+ 138 - 55
core/mapping/unmarshaler_test.go

@@ -3,7 +3,6 @@ package mapping
 import (
 	"encoding/json"
 	"fmt"
-	"os"
 	"strconv"
 	"strings"
 	"testing"
@@ -1454,18 +1453,42 @@ func TestUnmarshalMapOfStructError(t *testing.T) {
 }
 
 func TestUnmarshalSlice(t *testing.T) {
-	m := map[string]any{
-		"Ids": []any{"first", "second"},
-	}
-	var v struct {
-		Ids []string
-	}
-	ast := assert.New(t)
-	if ast.NoError(UnmarshalKey(m, &v)) {
-		ast.Equal(2, len(v.Ids))
-		ast.Equal("first", v.Ids[0])
-		ast.Equal("second", v.Ids[1])
-	}
+	t.Run("slice of string", func(t *testing.T) {
+		m := map[string]any{
+			"Ids": []any{"first", "second"},
+		}
+		var v struct {
+			Ids []string
+		}
+		ast := assert.New(t)
+		if ast.NoError(UnmarshalKey(m, &v)) {
+			ast.Equal(2, len(v.Ids))
+			ast.Equal("first", v.Ids[0])
+			ast.Equal("second", v.Ids[1])
+		}
+	})
+
+	t.Run("slice with type mismatch", func(t *testing.T) {
+		var v struct {
+			Ids string
+		}
+		assert.Error(t, NewUnmarshaler(jsonTagKey).Unmarshal([]any{1, 2}, &v))
+	})
+
+	t.Run("slice", func(t *testing.T) {
+		var v []int
+		ast := assert.New(t)
+		if ast.NoError(NewUnmarshaler(jsonTagKey).Unmarshal([]any{1, 2}, &v)) {
+			ast.Equal(2, len(v))
+			ast.Equal(1, v[0])
+			ast.Equal(2, v[1])
+		}
+	})
+
+	t.Run("slice with unsupported type", func(t *testing.T) {
+		var v int
+		assert.Error(t, NewUnmarshaler(jsonTagKey).Unmarshal(1, &v))
+	})
 }
 
 func TestUnmarshalSliceOfStruct(t *testing.T) {
@@ -3529,8 +3552,7 @@ func TestUnmarshal_EnvString(t *testing.T) {
 		envName = "TEST_NAME_STRING"
 		envVal  = "this is a name"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3547,8 +3569,7 @@ func TestUnmarshal_EnvStringOverwrite(t *testing.T) {
 		envName = "TEST_NAME_STRING"
 		envVal  = "this is a name"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(map[string]any{
@@ -3567,8 +3588,7 @@ func TestUnmarshal_EnvInt(t *testing.T) {
 		envName = "TEST_NAME_INT"
 		envVal  = "123"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3585,8 +3605,7 @@ func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
 		envName = "TEST_NAME_INT"
 		envVal  = "123"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(map[string]any{
@@ -3605,8 +3624,7 @@ func TestUnmarshal_EnvFloat(t *testing.T) {
 		envName = "TEST_NAME_FLOAT"
 		envVal  = "123.45"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3623,8 +3641,7 @@ func TestUnmarshal_EnvFloatOverwrite(t *testing.T) {
 		envName = "TEST_NAME_FLOAT"
 		envVal  = "123.45"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(map[string]any{
@@ -3643,8 +3660,7 @@ func TestUnmarshal_EnvBoolTrue(t *testing.T) {
 		envName = "TEST_NAME_BOOL_TRUE"
 		envVal  = "true"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3661,8 +3677,7 @@ func TestUnmarshal_EnvBoolFalse(t *testing.T) {
 		envName = "TEST_NAME_BOOL_FALSE"
 		envVal  = "false"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3679,8 +3694,7 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
 		envName = "TEST_NAME_BOOL_BAD"
 		envVal  = "bad"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3695,8 +3709,7 @@ func TestUnmarshal_EnvDuration(t *testing.T) {
 		envName = "TEST_NAME_DURATION"
 		envVal  = "1s"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3713,8 +3726,7 @@ func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
 		envName = "TEST_NAME_BAD_DURATION"
 		envVal  = "bad"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3729,8 +3741,7 @@ func TestUnmarshal_EnvWithOptions(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_MATCH"
 		envVal  = "123"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
@@ -3747,8 +3758,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueBool(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_BOOL"
 		envVal  = "false"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3763,8 +3773,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueDuration(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_DURATION"
 		envVal  = "4s"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3779,8 +3788,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueNumber(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_AGE"
 		envVal  = "30"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -3795,8 +3803,7 @@ func TestUnmarshal_EnvWithOptionsWrongValueString(t *testing.T) {
 		envName = "TEST_NAME_ENV_OPTIONS_STRING"
 		envVal  = "this is a name"
 	)
-	os.Setenv(envName, envVal)
-	defer os.Unsetenv(envName)
+	t.Setenv(envName, envVal)
 
 	var v Value
 	assert.Error(t, UnmarshalKey(emptyMap, &v))
@@ -4408,18 +4415,80 @@ func TestFillDefaultUnmarshal(t *testing.T) {
 }
 
 func Test_UnmarshalMap(t *testing.T) {
-	type Customer struct {
-		Names map[int]string `key:"names"`
-	}
+	t.Run("type mismatch", func(t *testing.T) {
+		type Customer struct {
+			Names map[int]string `key:"names"`
+		}
 
-	input := map[string]any{
-		"names": map[string]any{
-			"19": "Tom",
-		},
-	}
+		input := map[string]any{
+			"names": map[string]any{
+				"19": "Tom",
+			},
+		}
+
+		var customer Customer
+		assert.ErrorIs(t, UnmarshalKey(input, &customer), errTypeMismatch)
+	})
+
+	t.Run("map type mismatch", func(t *testing.T) {
+		type Customer struct {
+			Names struct {
+				Values map[string]string
+			} `key:"names"`
+		}
+
+		input := map[string]any{
+			"names": map[string]string{
+				"19": "Tom",
+			},
+		}
+
+		var customer Customer
+		assert.ErrorIs(t, UnmarshalKey(input, &customer), errTypeMismatch)
+	})
+}
+
+func TestGetValueWithChainedKeys(t *testing.T) {
+	t.Run("no key", func(t *testing.T) {
+		_, ok := getValueWithChainedKeys(nil, []string{})
+		assert.False(t, ok)
+	})
 
-	var customer Customer
-	assert.ErrorIs(t, UnmarshalKey(input, &customer), errTypeMismatch)
+	t.Run("one key", func(t *testing.T) {
+		v, ok := getValueWithChainedKeys(mockValuerWithParent{
+			value: "bar",
+			ok:    true,
+		}, []string{"foo"})
+		assert.True(t, ok)
+		assert.Equal(t, "bar", v)
+	})
+
+	t.Run("two keys", func(t *testing.T) {
+		v, ok := getValueWithChainedKeys(mockValuerWithParent{
+			value: map[string]any{
+				"bar": "baz",
+			},
+			ok: true,
+		}, []string{"foo", "bar"})
+		assert.True(t, ok)
+		assert.Equal(t, "baz", v)
+	})
+
+	t.Run("two keys not found", func(t *testing.T) {
+		_, ok := getValueWithChainedKeys(mockValuerWithParent{
+			value: "bar",
+			ok:    false,
+		}, []string{"foo", "bar"})
+		assert.False(t, ok)
+	})
+
+	t.Run("two keys type mismatch", func(t *testing.T) {
+		_, ok := getValueWithChainedKeys(mockValuerWithParent{
+			value: "bar",
+			ok:    true,
+		}, []string{"foo", "bar"})
+		assert.False(t, ok)
+	})
 }
 
 func BenchmarkDefaultValue(b *testing.B) {
@@ -4521,3 +4590,17 @@ func BenchmarkUnmarshal(b *testing.B) {
 		UnmarshalKey(data, &an)
 	}
 }
+
+type mockValuerWithParent struct {
+	parent valuerWithParent
+	value  any
+	ok     bool
+}
+
+func (m mockValuerWithParent) Value(key string) (any, bool) {
+	return m.value, m.ok
+}
+
+func (m mockValuerWithParent) Parent() valuerWithParent {
+	return m.parent
+}

+ 2 - 5
core/proc/env_test.go

@@ -1,7 +1,6 @@
 package proc
 
 import (
-	"os"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -21,13 +20,11 @@ func TestEnvInt(t *testing.T) {
 	val, ok := EnvInt("any")
 	assert.Equal(t, 0, val)
 	assert.False(t, ok)
-	err := os.Setenv("anyInt", "10")
-	assert.Nil(t, err)
+	t.Setenv("anyInt", "10")
 	val, ok = EnvInt("anyInt")
 	assert.Equal(t, 10, val)
 	assert.True(t, ok)
-	err = os.Setenv("anyString", "a")
-	assert.Nil(t, err)
+	t.Setenv("anyString", "a")
 	val, ok = EnvInt("anyString")
 	assert.Equal(t, 0, val)
 	assert.False(t, ok)

+ 1 - 3
core/stat/alert_test.go

@@ -3,7 +3,6 @@
 package stat
 
 import (
-	"os"
 	"strconv"
 	"sync/atomic"
 	"testing"
@@ -12,8 +11,7 @@ import (
 )
 
 func TestReport(t *testing.T) {
-	os.Setenv(clusterNameKey, "test-cluster")
-	defer os.Unsetenv(clusterNameKey)
+	t.Setenv(clusterNameKey, "test-cluster")
 
 	var count int32
 	SetReporter(func(s string) {