@@ -2,7 +2,6 @@ package conf
2
2
3
3
import (
4
4
"encoding"
5
- "errors"
6
5
"fmt"
7
6
"log/slog"
8
7
"os"
@@ -11,76 +10,58 @@ import (
11
10
"strings"
12
11
)
13
12
14
- type errUnsupportedType struct {
15
- fieldType reflect.Type
16
- }
17
-
18
- func (e errUnsupportedType ) Error () string {
19
- return fmt .Sprintf ("unsupported field type: %s" , e .fieldType )
20
- }
13
+ var textMarshalerType = reflect .TypeFor [encoding.TextMarshaler ]()
21
14
22
15
// Bind initializes the supplied object based on assoiciated struct tags
23
16
func Bind (ptr any ) error {
24
17
val := reflect .ValueOf (ptr )
25
18
if val .Kind () != reflect .Pointer {
26
- return errors . New ( "bind requires pointer types" )
19
+ return errPointerRequired
27
20
}
28
21
29
22
val = val .Elem ()
30
23
if val .Kind () != reflect .Struct {
31
- return errors . New ( "bind requires struct types" )
24
+ return errStructRequired
32
25
}
33
26
34
27
return bindStruct (val )
35
28
}
36
29
37
30
func bindStruct (objVal reflect.Value ) error {
38
- setValue := func (field reflect.StructField , fieldVal reflect.Value ) error {
39
- if err := setToDefault (field , fieldVal ); err != nil {
40
- return err
41
- }
42
- setFromEnv (field , fieldVal )
43
-
44
- return nil
45
- }
46
-
47
31
for i := 0 ; i < objVal .NumField (); i ++ {
48
32
field := objVal .Type ().Field (i )
49
33
if ! field .IsExported () {
50
34
continue
51
35
}
52
36
53
- fieldVal := objVal .Field (i )
37
+ fieldVal := resolvePointers ( objVal .Field (i ) )
54
38
55
- // Walk through any pointer layers
56
- for fieldVal .Type ().Kind () == reflect .Pointer {
57
- if fieldVal .IsNil () {
58
- ptrType := fieldVal .Type ().Elem ()
59
- fieldVal .Set (reflect .New (ptrType ))
60
- }
61
- fieldVal = fieldVal .Elem ()
62
- }
63
-
64
- var err error
65
- // If this is a struct, we need to recurse
66
- if fieldVal .Kind () == reflect .Struct {
67
- // Unless this is a TextUnmarshaler
68
- if _ , ok := fieldVal .Addr ().Interface ().(encoding.TextUnmarshaler ); ok {
69
- err = setValue (field , fieldVal )
70
- } else {
71
- err = bindStruct (fieldVal )
39
+ // If this is a struct, we need to recurse unless it's a TextUnmarshaler
40
+ if fieldVal .Kind () == reflect .Struct && ! fieldVal .Type ().AssignableTo (textMarshalerType ) {
41
+ if err := bindStruct (fieldVal ); err != nil {
42
+ return err
72
43
}
73
44
} else {
74
- err = setValue (field , fieldVal )
75
- }
76
- if err != nil {
77
- return err
45
+ if err := setToDefault (field , fieldVal ); err != nil {
46
+ return err
47
+ }
48
+ setFromEnv ( field , fieldVal )
78
49
}
79
50
}
80
51
81
52
return nil
82
53
}
83
54
55
+ func resolvePointers (val reflect.Value ) reflect.Value {
56
+ for val .Type ().Kind () == reflect .Pointer {
57
+ if val .IsNil () {
58
+ val .Set (reflect .New (val .Type ().Elem ()))
59
+ }
60
+ val = val .Elem ()
61
+ }
62
+ return val
63
+ }
64
+
84
65
func setToDefault (field reflect.StructField , val reflect.Value ) error {
85
66
if defaultStr , ok := field .Tag .Lookup ("default" ); ok {
86
67
if err := set (val , defaultStr ); err != nil {
@@ -124,62 +105,60 @@ func set(val reflect.Value, str string) error {
124
105
if unmarshaler , ok := val .Addr ().Interface ().(encoding.TextUnmarshaler ); ok {
125
106
return unmarshaler .UnmarshalText ([]byte (str ))
126
107
}
127
- return errUnsupportedType {val .Type ()}
108
+ return & errUnsupportedType {val .Type ()}
128
109
129
110
case reflect .String :
130
111
val .SetString (str )
112
+ return nil
131
113
132
114
case reflect .Int , reflect .Int8 , reflect .Int16 , reflect .Int32 , reflect .Int64 :
133
- typed , err := strconv .ParseInt (str , 10 , val .Type ().Bits ())
134
- if err != nil {
135
- return err
136
- }
137
- val .SetInt (typed )
115
+ return convertAndSet (str , func (str string ) (int64 , error ) {
116
+ return strconv .ParseInt (str , 0 , val .Type ().Bits ())
117
+ }, val .SetInt )
138
118
139
119
case reflect .Uint , reflect .Uint8 , reflect .Uint16 , reflect .Uint32 , reflect .Uint64 :
140
- typed , err := strconv .ParseUint (str , 10 , val .Type ().Bits ())
141
- if err != nil {
142
- return err
143
- }
144
- val .SetUint (typed )
120
+ return convertAndSet (str , func (str string ) (uint64 , error ) {
121
+ return strconv .ParseUint (str , 0 , val .Type ().Bits ())
122
+ }, val .SetUint )
145
123
146
124
case reflect .Float32 , reflect .Float64 :
147
- typed , err := strconv .ParseFloat (str , val .Type ().Bits ())
148
- if err != nil {
149
- return err
150
- }
151
- val .SetFloat (typed )
125
+ return convertAndSet (str , func (str string ) (float64 , error ) {
126
+ return strconv .ParseFloat (str , val .Type ().Bits ())
127
+ }, val .SetFloat )
152
128
153
129
case reflect .Complex64 , reflect .Complex128 :
154
- typed , err := strconv .ParseComplex (str , val .Type ().Bits ())
155
- if err != nil {
156
- return err
157
- }
158
- val .SetComplex (typed )
130
+ return convertAndSet (str , func (str string ) (complex128 , error ) {
131
+ return strconv .ParseComplex (str , val .Type ().Bits ())
132
+ }, val .SetComplex )
159
133
160
134
case reflect .Bool :
161
- typed , err := strconv .ParseBool (str )
162
- if err != nil {
163
- return err
164
- }
165
- val .SetBool (typed )
135
+ return convertAndSet (str , strconv .ParseBool , val .SetBool )
166
136
167
137
case reflect .Array , reflect .Slice :
168
- elementType := val .Type ().Elem ()
169
- segments := strings .Split (str , "," )
170
- newVal := reflect .MakeSlice (val .Type (), 0 , len (segments ))
171
- for _ , segment := range segments {
172
- element := reflect .New (elementType ).Elem ()
173
- if err := set (element , strings .TrimSpace (segment )); err != nil {
174
- return err
138
+ return convertAndSet (str , func (str string ) (reflect.Value , error ) {
139
+ valType := val .Type ()
140
+ segments := strings .Split (str , "," )
141
+ newVal := reflect .MakeSlice (valType , 0 , len (segments ))
142
+ for _ , segment := range segments {
143
+ element := reflect .New (valType .Elem ()).Elem ()
144
+ if err := set (element , strings .TrimSpace (segment )); err != nil {
145
+ return reflect .Zero (valType ), err
146
+ }
147
+ newVal = reflect .Append (newVal , element )
175
148
}
176
- newVal = reflect .Append (newVal , element )
177
- }
178
- val .Set (newVal )
149
+ return newVal , nil
150
+ }, val .Set )
179
151
180
152
default :
181
- return errUnsupportedType {val .Type ()}
153
+ return & errUnsupportedType {val .Type ()}
182
154
}
155
+ }
183
156
157
+ func convertAndSet [T any ](str string , converter func (str string ) (T , error ), setter func (val T )) error {
158
+ typed , err := converter (str )
159
+ if err != nil {
160
+ return err
161
+ }
162
+ setter (typed )
184
163
return nil
185
164
}
0 commit comments