Skip to content

Commit b0e8c20

Browse files
emil2kkisielk
authored andcommitted
Fix slice of structs TextUnmarshaler. (#103)
Fix handling of situation where a slice of structs implments the encoding.TextUnmarshaler interface, previously it would return "invalid path" error. Includes, some minor refactoring and documentation to clarify `isUnmarshaler` output.
1 parent afe7739 commit b0e8c20

File tree

4 files changed

+122
-42
lines changed

4 files changed

+122
-42
lines changed

cache.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) {
6363
}
6464
// Valid field. Append index.
6565
path = append(path, field.name)
66-
if field.ss {
66+
if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) {
6767
// Parse a special case: slices of structs.
6868
// i+1 must be the slice index.
6969
//
@@ -142,7 +142,7 @@ func (c *cache) create(t reflect.Type, info *structInfo) *structInfo {
142142
c.create(ft, info)
143143
for _, fi := range info.fields[bef:len(info.fields)] {
144144
// exclude required check because duplicated to embedded field
145-
fi.required = false
145+
fi.isRequired = false
146146
}
147147
}
148148
}
@@ -162,6 +162,7 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) {
162162
// First let's get the basic type.
163163
isSlice, isStruct := false, false
164164
ft := field.Type
165+
m := isTextUnmarshaler(reflect.Zero(ft))
165166
if ft.Kind() == reflect.Ptr {
166167
ft = ft.Elem()
167168
}
@@ -185,12 +186,13 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) {
185186
}
186187

187188
info.fields = append(info.fields, &fieldInfo{
188-
typ: field.Type,
189-
name: field.Name,
190-
ss: isSlice && isStruct,
191-
alias: alias,
192-
anon: field.Anonymous,
193-
required: options.Contains("required"),
189+
typ: field.Type,
190+
name: field.Name,
191+
alias: alias,
192+
unmarshalerInfo: m,
193+
isSliceOfStructs: isSlice && isStruct,
194+
isAnonymous: field.Anonymous,
195+
isRequired: options.Contains("required"),
194196
})
195197
}
196198

@@ -215,12 +217,18 @@ func (i *structInfo) get(alias string) *fieldInfo {
215217
}
216218

217219
type fieldInfo struct {
218-
typ reflect.Type
219-
name string // field name in the struct.
220-
ss bool // true if this is a slice of structs.
221-
alias string
222-
anon bool // is an embedded field
223-
required bool // tag option
220+
typ reflect.Type
221+
// name is the field name in the struct.
222+
name string
223+
alias string
224+
// unmarshalerInfo contains information regarding the
225+
// encoding.TextUnmarshaler implementation of the field type.
226+
unmarshalerInfo unmarshaler
227+
// isSliceOfStructs indicates if the field type is a slice of structs.
228+
isSliceOfStructs bool
229+
// isAnonymous indicates whether the field is embedded in the struct.
230+
isAnonymous bool
231+
isRequired bool
224232
}
225233

226234
type pathPart struct {

decoder.go

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string, prefix
106106
if f.typ.Kind() == reflect.Struct {
107107
err := d.checkRequired(f.typ, src, prefix+f.alias+".")
108108
if err != nil {
109-
if !f.anon {
109+
if !f.isAnonymous {
110110
return err
111111
}
112112
// check embedded parent field.
@@ -116,7 +116,7 @@ func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string, prefix
116116
}
117117
}
118118
}
119-
if f.required {
119+
if f.isRequired {
120120
key := f.alias
121121
if prefix != "" {
122122
key = prefix + key
@@ -185,7 +185,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
185185
// Get the converter early in case there is one for a slice type.
186186
conv := d.cache.converter(t)
187187
m := isTextUnmarshaler(v)
188-
if conv == nil && t.Kind() == reflect.Slice && m.IsSlice {
188+
if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
189189
var items []reflect.Value
190190
elemT := t.Elem()
191191
isPtrElem := elemT.Kind() == reflect.Ptr
@@ -211,7 +211,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
211211
}
212212
} else if m.IsValid {
213213
u := reflect.New(elemT)
214-
if m.IsPtr {
214+
if m.IsSliceElementPtr {
215215
u = reflect.New(reflect.PtrTo(elemT).Elem())
216216
}
217217
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
@@ -222,7 +222,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
222222
Err: err,
223223
}
224224
}
225-
if m.IsPtr {
225+
if m.IsSliceElementPtr {
226226
items = append(items, u.Elem().Addr())
227227
} else if u.Kind() == reflect.Ptr {
228228
items = append(items, u.Elem())
@@ -298,14 +298,27 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
298298
}
299299
}
300300
} else if m.IsValid {
301-
// If the value implements the encoding.TextUnmarshaler interface
302-
// apply UnmarshalText as the converter
303-
if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
304-
return ConversionError{
305-
Key: path,
306-
Type: t,
307-
Index: -1,
308-
Err: err,
301+
if m.IsPtr {
302+
u := reflect.New(v.Type())
303+
if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
304+
return ConversionError{
305+
Key: path,
306+
Type: t,
307+
Index: -1,
308+
Err: err,
309+
}
310+
}
311+
v.Set(reflect.Indirect(u))
312+
} else {
313+
// If the value implements the encoding.TextUnmarshaler interface
314+
// apply UnmarshalText as the converter
315+
if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
316+
return ConversionError{
317+
Key: path,
318+
Type: t,
319+
Index: -1,
320+
Err: err,
321+
}
309322
}
310323
}
311324
} else if conv := builtinConverters[t.Kind()]; conv != nil {
@@ -326,31 +339,36 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
326339
}
327340

328341
func isTextUnmarshaler(v reflect.Value) unmarshaler {
329-
330342
// Create a new unmarshaller instance
331343
m := unmarshaler{}
332-
333-
// As the UnmarshalText function should be applied
334-
// to the pointer of the type, we convert the value to pointer.
335-
if v.CanAddr() {
336-
v = v.Addr()
337-
}
338344
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
339345
return m
340346
}
347+
// As the UnmarshalText function should be applied to the pointer of the
348+
// type, we check that type to see if it implements the necessary
349+
// method.
350+
if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
351+
m.IsPtr = true
352+
return m
353+
}
341354

342355
// if v is []T or *[]T create new T
343356
t := v.Type()
344357
if t.Kind() == reflect.Ptr {
345358
t = t.Elem()
346359
}
347360
if t.Kind() == reflect.Slice {
348-
// if t is a pointer slice, check if it implements encoding.TextUnmarshaler
349-
m.IsSlice = true
361+
// Check if the slice implements encoding.TextUnmarshaller
362+
if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
363+
return m
364+
}
365+
// If t is a pointer slice, check if its elements implement
366+
// encoding.TextUnmarshaler
367+
m.IsSliceElement = true
350368
if t = t.Elem(); t.Kind() == reflect.Ptr {
351369
t = reflect.PtrTo(t.Elem())
352370
v = reflect.Zero(t)
353-
m.IsPtr = true
371+
m.IsSliceElementPtr = true
354372
m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
355373
return m
356374
}
@@ -365,9 +383,18 @@ func isTextUnmarshaler(v reflect.Value) unmarshaler {
365383
// unmarshaller contains information about a TextUnmarshaler type
366384
type unmarshaler struct {
367385
Unmarshaler encoding.TextUnmarshaler
368-
IsSlice bool
369-
IsPtr bool
370-
IsValid bool
386+
// IsValid indicates whether the resolved type indicated by the other
387+
// flags implements the encoding.TextUnmarshaler interface.
388+
IsValid bool
389+
// IsPtr indicates that the resolved type is the pointer of the original
390+
// type.
391+
IsPtr bool
392+
// IsSliceElement indicates that the resolved type is a slice element of
393+
// the original type.
394+
IsSliceElement bool
395+
// IsSliceElementPtr indicates that the resolved type is a pointer to a
396+
// slice element of the original type.
397+
IsSliceElementPtr bool
371398
}
372399

373400
// Errors ---------------------------------------------------------------------

decoder_test.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1684,10 +1684,56 @@ func TestTextUnmarshalerTypeSlice(t *testing.T) {
16841684
}{}
16851685
decoder := NewDecoder()
16861686
if err := decoder.Decode(&s, data); err != nil {
1687-
t.Error("Error while decoding:", err)
1687+
t.Fatal("Error while decoding:", err)
16881688
}
16891689
expected := S20{"a", "b", "c"}
16901690
if !reflect.DeepEqual(expected, s.Value) {
16911691
t.Errorf("Expected %v errors, got %v", expected, s.Value)
16921692
}
16931693
}
1694+
1695+
type S21E struct{ ElementValue string }
1696+
1697+
func (e *S21E) UnmarshalText(text []byte) error {
1698+
*e = S21E{"x"}
1699+
return nil
1700+
}
1701+
1702+
type S21 []S21E
1703+
1704+
func (s *S21) UnmarshalText(text []byte) error {
1705+
*s = S21{{"a"}}
1706+
return nil
1707+
}
1708+
1709+
type S21B []S21E
1710+
1711+
// Test to ensure that if custom type base on a slice of structs implements an
1712+
// encoding.TextUnmarshaler interface it is unaffected by the special path
1713+
// requirements imposed on a slice of structs.
1714+
func TestTextUnmarshalerTypeSliceOfStructs(t *testing.T) {
1715+
data := map[string][]string{
1716+
"Value": []string{"raw a"},
1717+
}
1718+
// Implements encoding.TextUnmarshaler, should not throw invalid path
1719+
// error.
1720+
s := struct {
1721+
Value S21
1722+
}{}
1723+
decoder := NewDecoder()
1724+
if err := decoder.Decode(&s, data); err != nil {
1725+
t.Fatal("Error while decoding:", err)
1726+
}
1727+
expected := S21{{"a"}}
1728+
if !reflect.DeepEqual(expected, s.Value) {
1729+
t.Errorf("Expected %v errors, got %v", expected, s.Value)
1730+
}
1731+
// Does not implement encoding.TextUnmarshaler, should throw invalid
1732+
// path error.
1733+
sb := struct {
1734+
Value S21B
1735+
}{}
1736+
if err := decoder.Decode(&sb, data); err == invalidPath {
1737+
t.Fatal("Expecting invalid path error", err)
1738+
}
1739+
}

encoder_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,5 @@ func TestRegisterEncoderCustomArrayType(t *testing.T) {
415415
})
416416

417417
encoder.Encode(s, vals)
418-
t.Log(vals)
419418
}
420419
}

0 commit comments

Comments
 (0)