Skip to content

Commit 68ac8f5

Browse files
authored
Add support for TextUnmarshaler (#424)
* Add support for TextUnmarshaler * Fix TextUnmarshaler handling * Remove unused global * Handle pointers at a different layer
1 parent 2fb8fbf commit 68ac8f5

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

conf/conf.go

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package conf
22

33
import (
4+
"encoding"
45
"errors"
56
"fmt"
67
"log/slog"
@@ -30,26 +31,50 @@ func Bind(ptr any) error {
3031
return errors.New("bind requires struct types")
3132
}
3233

33-
return bindValue(val)
34+
return bindStruct(val)
3435
}
3536

36-
func bindValue(objVal reflect.Value) error {
37+
func bindStruct(objVal reflect.Value) error {
38+
setValue := func(field reflect.StructField, fieldVal reflect.Value) error {
39+
if err := setToDefault(field, fieldVal); err != nil {
40+
return err
41+
}
42+
setFromEnv(field, fieldVal)
43+
44+
return nil
45+
}
46+
3747
for i := 0; i < objVal.NumField(); i++ {
3848
field := objVal.Type().Field(i)
3949
if !field.IsExported() {
4050
continue
4151
}
4252

4353
fieldVal := objVal.Field(i)
54+
55+
// Walk through any pointer layers
56+
for fieldVal.Type().Kind() == reflect.Pointer {
57+
if fieldVal.IsNil() {
58+
ptrType := fieldVal.Type().Elem()
59+
fieldVal.Set(reflect.New(ptrType))
60+
}
61+
fieldVal = fieldVal.Elem()
62+
}
63+
64+
var err error
65+
// If this is a struct, we need to recurse
4466
if fieldVal.Kind() == reflect.Struct {
45-
if err := bindValue(fieldVal); err != nil {
46-
return err
67+
// Unless this is a TextUnmarshaler
68+
if _, ok := fieldVal.Addr().Interface().(encoding.TextUnmarshaler); ok {
69+
err = setValue(field, fieldVal)
70+
} else {
71+
err = bindStruct(fieldVal)
4772
}
4873
} else {
49-
if err := setToDefault(field, fieldVal); err != nil {
50-
return err
51-
}
52-
setFromEnv(field, fieldVal)
74+
err = setValue(field, fieldVal)
75+
}
76+
if err != nil {
77+
return err
5378
}
5479
}
5580

@@ -95,6 +120,12 @@ func setFromEnv(field reflect.StructField, val reflect.Value) {
95120

96121
func set(val reflect.Value, str string) error {
97122
switch val.Type().Kind() {
123+
case reflect.Struct:
124+
if unmarshaler, ok := val.Addr().Interface().(encoding.TextUnmarshaler); ok {
125+
return unmarshaler.UnmarshalText([]byte(str))
126+
}
127+
return errUnsupportedType{val.Type()}
128+
98129
case reflect.String:
99130
val.SetString(str)
100131

@@ -146,13 +177,6 @@ func set(val reflect.Value, str string) error {
146177
}
147178
val.Set(newVal)
148179

149-
case reflect.Pointer:
150-
ptrType := val.Type().Elem()
151-
if val.IsNil() {
152-
val.Set(reflect.New(ptrType))
153-
}
154-
return set(val.Elem(), str)
155-
156180
default:
157181
return errUnsupportedType{val.Type()}
158182
}

conf/conf_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package conf
33
import (
44
"reflect"
55
"testing"
6+
"time"
67

78
"github.com/chadweimer/gomp/utils"
89
)
@@ -40,6 +41,8 @@ func TestBind_Defaults(t *testing.T) {
4041
TestBoolArray []bool `default:"true,false"`
4142

4243
TestString string `default:"Hello, Tests!"`
44+
45+
TestTime time.Time `default:"2000-01-02T03:04:05Z"`
4346
}
4447
tests := []struct {
4548
name string
@@ -78,6 +81,8 @@ func TestBind_Defaults(t *testing.T) {
7881
TestBoolArray: []bool{true, false},
7982

8083
TestString: "Hello, Tests!",
84+
85+
TestTime: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC),
8186
},
8287
},
8388
}

0 commit comments

Comments
 (0)