|
1 | 1 | package conf
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "encoding" |
4 | 5 | "errors"
|
5 | 6 | "fmt"
|
6 | 7 | "log/slog"
|
@@ -30,26 +31,50 @@ func Bind(ptr any) error {
|
30 | 31 | return errors.New("bind requires struct types")
|
31 | 32 | }
|
32 | 33 |
|
33 |
| - return bindValue(val) |
| 34 | + return bindStruct(val) |
34 | 35 | }
|
35 | 36 |
|
36 |
| -func bindValue(objVal reflect.Value) error { |
| 37 | +func bindStruct(objVal reflect.Value) error { |
| 38 | + setValue := func(field reflect.StructField, fieldVal reflect.Value) error { |
| 39 | + if err := setToDefault(field, fieldVal); err != nil { |
| 40 | + return err |
| 41 | + } |
| 42 | + setFromEnv(field, fieldVal) |
| 43 | + |
| 44 | + return nil |
| 45 | + } |
| 46 | + |
37 | 47 | for i := 0; i < objVal.NumField(); i++ {
|
38 | 48 | field := objVal.Type().Field(i)
|
39 | 49 | if !field.IsExported() {
|
40 | 50 | continue
|
41 | 51 | }
|
42 | 52 |
|
43 | 53 | fieldVal := objVal.Field(i)
|
| 54 | + |
| 55 | + // Walk through any pointer layers |
| 56 | + for fieldVal.Type().Kind() == reflect.Pointer { |
| 57 | + if fieldVal.IsNil() { |
| 58 | + ptrType := fieldVal.Type().Elem() |
| 59 | + fieldVal.Set(reflect.New(ptrType)) |
| 60 | + } |
| 61 | + fieldVal = fieldVal.Elem() |
| 62 | + } |
| 63 | + |
| 64 | + var err error |
| 65 | + // If this is a struct, we need to recurse |
44 | 66 | if fieldVal.Kind() == reflect.Struct {
|
45 |
| - if err := bindValue(fieldVal); err != nil { |
46 |
| - return err |
| 67 | + // Unless this is a TextUnmarshaler |
| 68 | + if _, ok := fieldVal.Addr().Interface().(encoding.TextUnmarshaler); ok { |
| 69 | + err = setValue(field, fieldVal) |
| 70 | + } else { |
| 71 | + err = bindStruct(fieldVal) |
47 | 72 | }
|
48 | 73 | } else {
|
49 |
| - if err := setToDefault(field, fieldVal); err != nil { |
50 |
| - return err |
51 |
| - } |
52 |
| - setFromEnv(field, fieldVal) |
| 74 | + err = setValue(field, fieldVal) |
| 75 | + } |
| 76 | + if err != nil { |
| 77 | + return err |
53 | 78 | }
|
54 | 79 | }
|
55 | 80 |
|
@@ -95,6 +120,12 @@ func setFromEnv(field reflect.StructField, val reflect.Value) {
|
95 | 120 |
|
96 | 121 | func set(val reflect.Value, str string) error {
|
97 | 122 | switch val.Type().Kind() {
|
| 123 | + case reflect.Struct: |
| 124 | + if unmarshaler, ok := val.Addr().Interface().(encoding.TextUnmarshaler); ok { |
| 125 | + return unmarshaler.UnmarshalText([]byte(str)) |
| 126 | + } |
| 127 | + return errUnsupportedType{val.Type()} |
| 128 | + |
98 | 129 | case reflect.String:
|
99 | 130 | val.SetString(str)
|
100 | 131 |
|
@@ -146,13 +177,6 @@ func set(val reflect.Value, str string) error {
|
146 | 177 | }
|
147 | 178 | val.Set(newVal)
|
148 | 179 |
|
149 |
| - case reflect.Pointer: |
150 |
| - ptrType := val.Type().Elem() |
151 |
| - if val.IsNil() { |
152 |
| - val.Set(reflect.New(ptrType)) |
153 |
| - } |
154 |
| - return set(val.Elem(), str) |
155 |
| - |
156 | 180 | default:
|
157 | 181 | return errUnsupportedType{val.Type()}
|
158 | 182 | }
|
|
0 commit comments