Bläddra i källkod

feat: support env tag in config (#2577)

* feat: support env tag in config

* chore: add more tests

* chore: add more tests, add stringx.Join

* fix: test fail

* chore: remove print code

* chore: rename variable
Kevin Wan 2 år sedan
förälder
incheckning
69068cdaf0

+ 1 - 2
core/hash/consistenthash.go

@@ -7,7 +7,6 @@ import (
 	"sync"
 	"sync"
 
 
 	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/core/lang"
-	"github.com/zeromicro/go-zero/core/mapping"
 )
 )
 
 
 const (
 const (
@@ -183,5 +182,5 @@ func innerRepr(node interface{}) string {
 }
 }
 
 
 func repr(node interface{}) string {
 func repr(node interface{}) string {
-	return mapping.Repr(node)
+	return lang.Repr(node)
 }
 }

+ 67 - 0
core/lang/lang.go

@@ -1,5 +1,11 @@
 package lang
 package lang
 
 
+import (
+	"fmt"
+	"reflect"
+	"strconv"
+)
+
 // Placeholder is a placeholder object that can be used globally.
 // Placeholder is a placeholder object that can be used globally.
 var Placeholder PlaceholderType
 var Placeholder PlaceholderType
 
 
@@ -9,3 +15,64 @@ type (
 	// PlaceholderType represents a placeholder type.
 	// PlaceholderType represents a placeholder type.
 	PlaceholderType = struct{}
 	PlaceholderType = struct{}
 )
 )
+
+// Repr returns the string representation of v.
+func Repr(v interface{}) string {
+	if v == nil {
+		return ""
+	}
+
+	// if func (v *Type) String() string, we can't use Elem()
+	switch vt := v.(type) {
+	case fmt.Stringer:
+		return vt.String()
+	}
+
+	val := reflect.ValueOf(v)
+	if val.Kind() == reflect.Ptr && !val.IsNil() {
+		val = val.Elem()
+	}
+
+	return reprOfValue(val)
+}
+
+func reprOfValue(val reflect.Value) string {
+	switch vt := val.Interface().(type) {
+	case bool:
+		return strconv.FormatBool(vt)
+	case error:
+		return vt.Error()
+	case float32:
+		return strconv.FormatFloat(float64(vt), 'f', -1, 32)
+	case float64:
+		return strconv.FormatFloat(vt, 'f', -1, 64)
+	case fmt.Stringer:
+		return vt.String()
+	case int:
+		return strconv.Itoa(vt)
+	case int8:
+		return strconv.Itoa(int(vt))
+	case int16:
+		return strconv.Itoa(int(vt))
+	case int32:
+		return strconv.Itoa(int(vt))
+	case int64:
+		return strconv.FormatInt(vt, 10)
+	case string:
+		return vt
+	case uint:
+		return strconv.FormatUint(uint64(vt), 10)
+	case uint8:
+		return strconv.FormatUint(uint64(vt), 10)
+	case uint16:
+		return strconv.FormatUint(uint64(vt), 10)
+	case uint32:
+		return strconv.FormatUint(uint64(vt), 10)
+	case uint64:
+		return strconv.FormatUint(vt, 10)
+	case []byte:
+		return string(vt)
+	default:
+		return fmt.Sprint(val.Interface())
+	}
+}

+ 131 - 0
core/lang/lang_test.go

@@ -0,0 +1,131 @@
+package lang
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRepr(t *testing.T) {
+	var (
+		f32 float32 = 1.1
+		f64         = 2.2
+		i8  int8    = 1
+		i16 int16   = 2
+		i32 int32   = 3
+		i64 int64   = 4
+		u8  uint8   = 5
+		u16 uint16  = 6
+		u32 uint32  = 7
+		u64 uint64  = 8
+	)
+	tests := []struct {
+		v      interface{}
+		expect string
+	}{
+		{
+			nil,
+			"",
+		},
+		{
+			mockStringable{},
+			"mocked",
+		},
+		{
+			new(mockStringable),
+			"mocked",
+		},
+		{
+			newMockPtr(),
+			"mockptr",
+		},
+		{
+			&mockOpacity{
+				val: 1,
+			},
+			"{1}",
+		},
+		{
+			true,
+			"true",
+		},
+		{
+			false,
+			"false",
+		},
+		{
+			f32,
+			"1.1",
+		},
+		{
+			f64,
+			"2.2",
+		},
+		{
+			i8,
+			"1",
+		},
+		{
+			i16,
+			"2",
+		},
+		{
+			i32,
+			"3",
+		},
+		{
+			i64,
+			"4",
+		},
+		{
+			u8,
+			"5",
+		},
+		{
+			u16,
+			"6",
+		},
+		{
+			u32,
+			"7",
+		},
+		{
+			u64,
+			"8",
+		},
+		{
+			[]byte(`abcd`),
+			"abcd",
+		},
+		{
+			mockOpacity{val: 1},
+			"{1}",
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.expect, func(t *testing.T) {
+			assert.Equal(t, test.expect, Repr(test.v))
+		})
+	}
+}
+
+type mockStringable struct{}
+
+func (m mockStringable) String() string {
+	return "mocked"
+}
+
+type mockPtr struct{}
+
+func newMockPtr() *mockPtr {
+	return new(mockPtr)
+}
+
+func (m *mockPtr) String() string {
+	return "mockptr"
+}
+
+type mockOpacity struct {
+	val int
+}

+ 2 - 0
core/mapping/fieldoptions.go

@@ -13,6 +13,7 @@ type (
 		Optional   bool
 		Optional   bool
 		Options    []string
 		Options    []string
 		Default    string
 		Default    string
+		EnvVar     string
 		Range      *numberRange
 		Range      *numberRange
 	}
 	}
 
 
@@ -106,5 +107,6 @@ func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName strin
 		Optional:   optional,
 		Optional:   optional,
 		Options:    o.Options,
 		Options:    o.Options,
 		Default:    o.Default,
 		Default:    o.Default,
+		EnvVar:     o.EnvVar,
 	}, nil
 	}, nil
 }
 }

+ 27 - 2
core/mapping/unmarshaler.go

@@ -12,6 +12,7 @@ import (
 
 
 	"github.com/zeromicro/go-zero/core/jsonx"
 	"github.com/zeromicro/go-zero/core/jsonx"
 	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/core/lang"
+	"github.com/zeromicro/go-zero/core/proc"
 	"github.com/zeromicro/go-zero/core/stringx"
 	"github.com/zeromicro/go-zero/core/stringx"
 )
 )
 
 
@@ -92,8 +93,7 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
 	rve := rv.Elem()
 	rve := rv.Elem()
 	numFields := rte.NumField()
 	numFields := rte.NumField()
 	for i := 0; i < numFields; i++ {
 	for i := 0; i < numFields; i++ {
-		field := rte.Field(i)
-		if err := u.processField(field, rve.Field(i), m, fullName); err != nil {
+		if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -338,6 +338,24 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(field reflect.StructField, val
 	return false, nil
 	return false, nil
 }
 }
 
 
+func (u *Unmarshaler) processFieldWithEnvValue(field reflect.StructField, value reflect.Value,
+	envVal string, opts *fieldOptionsWithContext, fullName string) error {
+	fieldKind := field.Type.Kind()
+	switch fieldKind {
+	case durationType.Kind():
+		if err := fillDurationValue(fieldKind, value, envVal); err != nil {
+			return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
+		}
+
+		return nil
+	case reflect.String:
+		value.SetString(envVal)
+		return nil
+	default:
+		return u.processFieldPrimitiveWithJSONNumber(field, value, json.Number(envVal), opts, fullName)
+	}
+}
+
 func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value,
 func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value,
 	m valuerWithParent, fullName string) error {
 	m valuerWithParent, fullName string) error {
 	key, opts, err := u.parseOptionsWithContext(field, m, fullName)
 	key, opts, err := u.parseOptionsWithContext(field, m, fullName)
@@ -346,6 +364,13 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
 	}
 	}
 
 
 	fullName = join(fullName, key)
 	fullName = join(fullName, key)
+	if opts != nil && len(opts.EnvVar) > 0 {
+		envVal := proc.Env(opts.EnvVar)
+		if len(envVal) > 0 {
+			return u.processFieldWithEnvValue(field, value, envVal, opts, fullName)
+		}
+	}
+
 	canonicalKey := key
 	canonicalKey := key
 	if u.opts.canonicalKey != nil {
 	if u.opts.canonicalKey != nil {
 		canonicalKey = u.opts.canonicalKey(key)
 		canonicalKey = u.opts.canonicalKey(key)

+ 124 - 0
core/mapping/unmarshaler_test.go

@@ -3,6 +3,7 @@ package mapping
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
+	"os"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
@@ -3089,6 +3090,129 @@ func TestUnmarshalValuer(t *testing.T) {
 	assert.NotNil(t, err)
 	assert.NotNil(t, err)
 }
 }
 
 
+func TestUnmarshal_EnvString(t *testing.T) {
+	type Value struct {
+		Name string `key:"name,env=TEST_NAME_STRING"`
+	}
+
+	const (
+		envName = "TEST_NAME_STRING"
+		envVal  = "this is a name"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(emptyMap, &v))
+	assert.Equal(t, envVal, v.Name)
+}
+
+func TestUnmarshal_EnvStringOverwrite(t *testing.T) {
+	type Value struct {
+		Name string `key:"name,env=TEST_NAME_STRING"`
+	}
+
+	const (
+		envName = "TEST_NAME_STRING"
+		envVal  = "this is a name"
+	)
+	os.Setenv(envName, envVal)
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(map[string]interface{}{
+		"name": "local value",
+	}, &v))
+	assert.Equal(t, envVal, v.Name)
+}
+
+func TestUnmarshal_EnvInt(t *testing.T) {
+	type Value struct {
+		Age int `key:"age,env=TEST_NAME_INT"`
+	}
+
+	const envName = "TEST_NAME_INT"
+	os.Setenv(envName, "123")
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(emptyMap, &v))
+	assert.Equal(t, 123, v.Age)
+}
+
+func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
+	type Value struct {
+		Age int `key:"age,env=TEST_NAME_INT"`
+	}
+
+	const envName = "TEST_NAME_INT"
+	os.Setenv(envName, "123")
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(map[string]interface{}{
+		"age": 18,
+	}, &v))
+	assert.Equal(t, 123, v.Age)
+}
+
+func TestUnmarshal_EnvFloat(t *testing.T) {
+	type Value struct {
+		Age float32 `key:"name,env=TEST_NAME_FLOAT"`
+	}
+
+	const envName = "TEST_NAME_FLOAT"
+	os.Setenv(envName, "123.45")
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(emptyMap, &v))
+	assert.Equal(t, float32(123.45), v.Age)
+}
+
+func TestUnmarshal_EnvFloatOverwrite(t *testing.T) {
+	type Value struct {
+		Age float32 `key:"age,env=TEST_NAME_FLOAT"`
+	}
+
+	const envName = "TEST_NAME_FLOAT"
+	os.Setenv(envName, "123.45")
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(map[string]interface{}{
+		"age": 18.5,
+	}, &v))
+	assert.Equal(t, float32(123.45), v.Age)
+}
+
+func TestUnmarshal_EnvDuration(t *testing.T) {
+	type Value struct {
+		Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
+	}
+
+	const envName = "TEST_NAME_DURATION"
+	os.Setenv(envName, "1s")
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NoError(t, UnmarshalKey(emptyMap, &v))
+	assert.Equal(t, time.Second, v.Duration)
+}
+
+func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
+	type Value struct {
+		Duration time.Duration `key:"duration,env=TEST_NAME_BAD_DURATION"`
+	}
+
+	const envName = "TEST_NAME_BAD_DURATION"
+	os.Setenv(envName, "bad")
+	defer os.Unsetenv(envName)
+
+	var v Value
+	assert.NotNil(t, UnmarshalKey(emptyMap, &v))
+}
+
 func BenchmarkUnmarshalString(b *testing.B) {
 func BenchmarkUnmarshalString(b *testing.B) {
 	type inner struct {
 	type inner struct {
 		Value string `key:"value"`
 		Value string `key:"value"`

+ 31 - 69
core/mapping/utils.go

@@ -10,11 +10,13 @@ import (
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
+	"github.com/zeromicro/go-zero/core/lang"
 	"github.com/zeromicro/go-zero/core/stringx"
 	"github.com/zeromicro/go-zero/core/stringx"
 )
 )
 
 
 const (
 const (
 	defaultOption      = "default"
 	defaultOption      = "default"
+	envOption          = "env"
 	inheritOption      = "inherit"
 	inheritOption      = "inherit"
 	stringOption       = "string"
 	stringOption       = "string"
 	optionalOption     = "optional"
 	optionalOption     = "optional"
@@ -63,22 +65,7 @@ func Deref(t reflect.Type) reflect.Type {
 
 
 // Repr returns the string representation of v.
 // Repr returns the string representation of v.
 func Repr(v interface{}) string {
 func Repr(v interface{}) string {
-	if v == nil {
-		return ""
-	}
-
-	// if func (v *Type) String() string, we can't use Elem()
-	switch vt := v.(type) {
-	case fmt.Stringer:
-		return vt.String()
-	}
-
-	val := reflect.ValueOf(v)
-	if val.Kind() == reflect.Ptr && !val.IsNil() {
-		val = val.Elem()
-	}
-
-	return reprOfValue(val)
+	return lang.Repr(v)
 }
 }
 
 
 // ValidatePtr validates v if it's a valid pointer.
 // ValidatePtr validates v if it's a valid pointer.
@@ -354,26 +341,33 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
 	case option == optionalOption:
 	case option == optionalOption:
 		fieldOpts.Optional = true
 		fieldOpts.Optional = true
 	case strings.HasPrefix(option, optionsOption):
 	case strings.HasPrefix(option, optionsOption):
-		segs := strings.Split(option, equalToken)
-		if len(segs) != 2 {
-			return fmt.Errorf("field %s has wrong options", fieldName)
+		val, err := parseProperty(fieldName, optionsOption, option)
+		if err != nil {
+			return err
 		}
 		}
 
 
-		fieldOpts.Options = parseOptions(segs[1])
+		fieldOpts.Options = parseOptions(val)
 	case strings.HasPrefix(option, defaultOption):
 	case strings.HasPrefix(option, defaultOption):
-		segs := strings.Split(option, equalToken)
-		if len(segs) != 2 {
-			return fmt.Errorf("field %s has wrong default option", fieldName)
+		val, err := parseProperty(fieldName, defaultOption, option)
+		if err != nil {
+			return err
+		}
+
+		fieldOpts.Default = val
+	case strings.HasPrefix(option, envOption):
+		val, err := parseProperty(fieldName, envOption, option)
+		if err != nil {
+			return err
 		}
 		}
 
 
-		fieldOpts.Default = strings.TrimSpace(segs[1])
+		fieldOpts.EnvVar = val
 	case strings.HasPrefix(option, rangeOption):
 	case strings.HasPrefix(option, rangeOption):
-		segs := strings.Split(option, equalToken)
-		if len(segs) != 2 {
-			return fmt.Errorf("field %s has wrong range", fieldName)
+		val, err := parseProperty(fieldName, rangeOption, option)
+		if err != nil {
+			return err
 		}
 		}
 
 
-		nr, err := parseNumberRange(segs[1])
+		nr, err := parseNumberRange(val)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -398,6 +392,15 @@ func parseOptions(val string) []string {
 	return strings.Split(val, optionSeparator)
 	return strings.Split(val, optionSeparator)
 }
 }
 
 
+func parseProperty(field, tag, val string) (string, error) {
+	segs := strings.Split(val, equalToken)
+	if len(segs) != 2 {
+		return "", fmt.Errorf("field %s has wrong %s", field, tag)
+	}
+
+	return strings.TrimSpace(segs[1]), nil
+}
+
 func parseSegments(val string) []string {
 func parseSegments(val string) []string {
 	var segments []string
 	var segments []string
 	var escaped, grouped bool
 	var escaped, grouped bool
@@ -447,47 +450,6 @@ func parseSegments(val string) []string {
 	return segments
 	return segments
 }
 }
 
 
-func reprOfValue(val reflect.Value) string {
-	switch vt := val.Interface().(type) {
-	case bool:
-		return strconv.FormatBool(vt)
-	case error:
-		return vt.Error()
-	case float32:
-		return strconv.FormatFloat(float64(vt), 'f', -1, 32)
-	case float64:
-		return strconv.FormatFloat(vt, 'f', -1, 64)
-	case fmt.Stringer:
-		return vt.String()
-	case int:
-		return strconv.Itoa(vt)
-	case int8:
-		return strconv.Itoa(int(vt))
-	case int16:
-		return strconv.Itoa(int(vt))
-	case int32:
-		return strconv.Itoa(int(vt))
-	case int64:
-		return strconv.FormatInt(vt, 10)
-	case string:
-		return vt
-	case uint:
-		return strconv.FormatUint(uint64(vt), 10)
-	case uint8:
-		return strconv.FormatUint(uint64(vt), 10)
-	case uint16:
-		return strconv.FormatUint(uint64(vt), 10)
-	case uint32:
-		return strconv.FormatUint(uint64(vt), 10)
-	case uint64:
-		return strconv.FormatUint(vt, 10)
-	case []byte:
-		return string(vt)
-	default:
-		return fmt.Sprint(val.Interface())
-	}
-}
-
 func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
 func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
 	switch kind {
 	switch kind {
 	case reflect.Bool:
 	case reflect.Bool:

+ 0 - 124
core/mapping/utils_test.go

@@ -296,127 +296,3 @@ func TestSetValueFormatErrors(t *testing.T) {
 		})
 		})
 	}
 	}
 }
 }
-
-func TestRepr(t *testing.T) {
-	var (
-		f32 float32 = 1.1
-		f64         = 2.2
-		i8  int8    = 1
-		i16 int16   = 2
-		i32 int32   = 3
-		i64 int64   = 4
-		u8  uint8   = 5
-		u16 uint16  = 6
-		u32 uint32  = 7
-		u64 uint64  = 8
-	)
-	tests := []struct {
-		v      interface{}
-		expect string
-	}{
-		{
-			nil,
-			"",
-		},
-		{
-			mockStringable{},
-			"mocked",
-		},
-		{
-			new(mockStringable),
-			"mocked",
-		},
-		{
-			newMockPtr(),
-			"mockptr",
-		},
-		{
-			&mockOpacity{
-				val: 1,
-			},
-			"{1}",
-		},
-		{
-			true,
-			"true",
-		},
-		{
-			false,
-			"false",
-		},
-		{
-			f32,
-			"1.1",
-		},
-		{
-			f64,
-			"2.2",
-		},
-		{
-			i8,
-			"1",
-		},
-		{
-			i16,
-			"2",
-		},
-		{
-			i32,
-			"3",
-		},
-		{
-			i64,
-			"4",
-		},
-		{
-			u8,
-			"5",
-		},
-		{
-			u16,
-			"6",
-		},
-		{
-			u32,
-			"7",
-		},
-		{
-			u64,
-			"8",
-		},
-		{
-			[]byte(`abcd`),
-			"abcd",
-		},
-		{
-			mockOpacity{val: 1},
-			"{1}",
-		},
-	}
-
-	for _, test := range tests {
-		t.Run(test.expect, func(t *testing.T) {
-			assert.Equal(t, test.expect, Repr(test.v))
-		})
-	}
-}
-
-type mockStringable struct{}
-
-func (m mockStringable) String() string {
-	return "mocked"
-}
-
-type mockPtr struct{}
-
-func newMockPtr() *mockPtr {
-	return new(mockPtr)
-}
-
-func (m *mockPtr) String() string {
-	return "mockptr"
-}
-
-type mockOpacity struct {
-	val int
-}

+ 27 - 0
core/stringx/strings.go

@@ -69,6 +69,33 @@ func HasEmpty(args ...string) bool {
 	return false
 	return false
 }
 }
 
 
+// Join joins any number of elements into a single string, separating them with given sep.
+// Empty elements are ignored. However, if the argument list is empty or all its elements are empty,
+// Join returns an empty string.
+func Join(sep byte, elem ...string) string {
+	var size int
+	for _, e := range elem {
+		size += len(e)
+	}
+	if size == 0 {
+		return ""
+	}
+
+	buf := make([]byte, 0, size+len(elem)-1)
+	for _, e := range elem {
+		if len(e) == 0 {
+			continue
+		}
+
+		if len(buf) > 0 {
+			buf = append(buf, sep)
+		}
+		buf = append(buf, e...)
+	}
+
+	return string(buf)
+}
+
 // NotEmpty checks if all strings are not empty in args.
 // NotEmpty checks if all strings are not empty in args.
 func NotEmpty(args ...string) bool {
 func NotEmpty(args ...string) bool {
 	return !HasEmpty(args...)
 	return !HasEmpty(args...)

+ 36 - 0
core/stringx/strings_test.go

@@ -147,6 +147,42 @@ func TestFirstN(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestJoin(t *testing.T) {
+	tests := []struct {
+		name   string
+		input  []string
+		expect string
+	}{
+		{
+			name:   "all blanks",
+			input:  []string{"", ""},
+			expect: "",
+		},
+		{
+			name:   "two values",
+			input:  []string{"012", "abc"},
+			expect: "012.abc",
+		},
+		{
+			name:   "last blank",
+			input:  []string{"abc", ""},
+			expect: "abc",
+		},
+		{
+			name:   "first blank",
+			input:  []string{"", "abc"},
+			expect: "abc",
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			assert.Equal(t, test.expect, Join('.', test.input...))
+		})
+	}
+}
+
 func TestRemove(t *testing.T) {
 func TestRemove(t *testing.T) {
 	cases := []struct {
 	cases := []struct {
 		input  []string
 		input  []string