Skip to content

Commit f44c13e

Browse files
authored
Cleanup and more test cases (#425)
1 parent 68ac8f5 commit f44c13e

File tree

3 files changed

+182
-91
lines changed

3 files changed

+182
-91
lines changed

conf/conf.go

Lines changed: 57 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package conf
22

33
import (
44
"encoding"
5-
"errors"
65
"fmt"
76
"log/slog"
87
"os"
@@ -11,76 +10,58 @@ import (
1110
"strings"
1211
)
1312

14-
type errUnsupportedType struct {
15-
fieldType reflect.Type
16-
}
17-
18-
func (e errUnsupportedType) Error() string {
19-
return fmt.Sprintf("unsupported field type: %s", e.fieldType)
20-
}
13+
var textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
2114

2215
// Bind initializes the supplied object based on assoiciated struct tags
2316
func Bind(ptr any) error {
2417
val := reflect.ValueOf(ptr)
2518
if val.Kind() != reflect.Pointer {
26-
return errors.New("bind requires pointer types")
19+
return errPointerRequired
2720
}
2821

2922
val = val.Elem()
3023
if val.Kind() != reflect.Struct {
31-
return errors.New("bind requires struct types")
24+
return errStructRequired
3225
}
3326

3427
return bindStruct(val)
3528
}
3629

3730
func bindStruct(objVal reflect.Value) error {
38-
setValue := func(field reflect.StructField, fieldVal reflect.Value) error {
39-
if err := setToDefault(field, fieldVal); err != nil {
40-
return err
41-
}
42-
setFromEnv(field, fieldVal)
43-
44-
return nil
45-
}
46-
4731
for i := 0; i < objVal.NumField(); i++ {
4832
field := objVal.Type().Field(i)
4933
if !field.IsExported() {
5034
continue
5135
}
5236

53-
fieldVal := objVal.Field(i)
37+
fieldVal := resolvePointers(objVal.Field(i))
5438

55-
// Walk through any pointer layers
56-
for fieldVal.Type().Kind() == reflect.Pointer {
57-
if fieldVal.IsNil() {
58-
ptrType := fieldVal.Type().Elem()
59-
fieldVal.Set(reflect.New(ptrType))
60-
}
61-
fieldVal = fieldVal.Elem()
62-
}
63-
64-
var err error
65-
// If this is a struct, we need to recurse
66-
if fieldVal.Kind() == reflect.Struct {
67-
// Unless this is a TextUnmarshaler
68-
if _, ok := fieldVal.Addr().Interface().(encoding.TextUnmarshaler); ok {
69-
err = setValue(field, fieldVal)
70-
} else {
71-
err = bindStruct(fieldVal)
39+
// If this is a struct, we need to recurse unless it's a TextUnmarshaler
40+
if fieldVal.Kind() == reflect.Struct && !fieldVal.Type().AssignableTo(textMarshalerType) {
41+
if err := bindStruct(fieldVal); err != nil {
42+
return err
7243
}
7344
} else {
74-
err = setValue(field, fieldVal)
75-
}
76-
if err != nil {
77-
return err
45+
if err := setToDefault(field, fieldVal); err != nil {
46+
return err
47+
}
48+
setFromEnv(field, fieldVal)
7849
}
7950
}
8051

8152
return nil
8253
}
8354

55+
func resolvePointers(val reflect.Value) reflect.Value {
56+
for val.Type().Kind() == reflect.Pointer {
57+
if val.IsNil() {
58+
val.Set(reflect.New(val.Type().Elem()))
59+
}
60+
val = val.Elem()
61+
}
62+
return val
63+
}
64+
8465
func setToDefault(field reflect.StructField, val reflect.Value) error {
8566
if defaultStr, ok := field.Tag.Lookup("default"); ok {
8667
if err := set(val, defaultStr); err != nil {
@@ -124,62 +105,60 @@ func set(val reflect.Value, str string) error {
124105
if unmarshaler, ok := val.Addr().Interface().(encoding.TextUnmarshaler); ok {
125106
return unmarshaler.UnmarshalText([]byte(str))
126107
}
127-
return errUnsupportedType{val.Type()}
108+
return &errUnsupportedType{val.Type()}
128109

129110
case reflect.String:
130111
val.SetString(str)
112+
return nil
131113

132114
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
133-
typed, err := strconv.ParseInt(str, 10, val.Type().Bits())
134-
if err != nil {
135-
return err
136-
}
137-
val.SetInt(typed)
115+
return convertAndSet(str, func(str string) (int64, error) {
116+
return strconv.ParseInt(str, 0, val.Type().Bits())
117+
}, val.SetInt)
138118

139119
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
140-
typed, err := strconv.ParseUint(str, 10, val.Type().Bits())
141-
if err != nil {
142-
return err
143-
}
144-
val.SetUint(typed)
120+
return convertAndSet(str, func(str string) (uint64, error) {
121+
return strconv.ParseUint(str, 0, val.Type().Bits())
122+
}, val.SetUint)
145123

146124
case reflect.Float32, reflect.Float64:
147-
typed, err := strconv.ParseFloat(str, val.Type().Bits())
148-
if err != nil {
149-
return err
150-
}
151-
val.SetFloat(typed)
125+
return convertAndSet(str, func(str string) (float64, error) {
126+
return strconv.ParseFloat(str, val.Type().Bits())
127+
}, val.SetFloat)
152128

153129
case reflect.Complex64, reflect.Complex128:
154-
typed, err := strconv.ParseComplex(str, val.Type().Bits())
155-
if err != nil {
156-
return err
157-
}
158-
val.SetComplex(typed)
130+
return convertAndSet(str, func(str string) (complex128, error) {
131+
return strconv.ParseComplex(str, val.Type().Bits())
132+
}, val.SetComplex)
159133

160134
case reflect.Bool:
161-
typed, err := strconv.ParseBool(str)
162-
if err != nil {
163-
return err
164-
}
165-
val.SetBool(typed)
135+
return convertAndSet(str, strconv.ParseBool, val.SetBool)
166136

167137
case reflect.Array, reflect.Slice:
168-
elementType := val.Type().Elem()
169-
segments := strings.Split(str, ",")
170-
newVal := reflect.MakeSlice(val.Type(), 0, len(segments))
171-
for _, segment := range segments {
172-
element := reflect.New(elementType).Elem()
173-
if err := set(element, strings.TrimSpace(segment)); err != nil {
174-
return err
138+
return convertAndSet(str, func(str string) (reflect.Value, error) {
139+
valType := val.Type()
140+
segments := strings.Split(str, ",")
141+
newVal := reflect.MakeSlice(valType, 0, len(segments))
142+
for _, segment := range segments {
143+
element := reflect.New(valType.Elem()).Elem()
144+
if err := set(element, strings.TrimSpace(segment)); err != nil {
145+
return reflect.Zero(valType), err
146+
}
147+
newVal = reflect.Append(newVal, element)
175148
}
176-
newVal = reflect.Append(newVal, element)
177-
}
178-
val.Set(newVal)
149+
return newVal, nil
150+
}, val.Set)
179151

180152
default:
181-
return errUnsupportedType{val.Type()}
153+
return &errUnsupportedType{val.Type()}
182154
}
155+
}
183156

157+
func convertAndSet[T any](str string, converter func(str string) (T, error), setter func(val T)) error {
158+
typed, err := converter(str)
159+
if err != nil {
160+
return err
161+
}
162+
setter(typed)
184163
return nil
185164
}

0 commit comments

Comments
 (0)