Skip to content

Commit 7d9bc33

Browse files
committed
schema: Allow AnyOf to implement FieldComparator
This changes the interface of FieldComparator so that it's possible for FieldValidator types to determine, either run-time or compile-time, if the validator is comparable. AnyOf has been modified to Allow comparisons if one of it's sub-validators allows it. Changes the comparable types Float, Integer, AnyOf and Time to test only the public interface; testing the private interface gives little value, as it should be allowed to change without test updates.
1 parent 2cf32f2 commit 7d9bc33

File tree

11 files changed

+242
-297
lines changed

11 files changed

+242
-297
lines changed

schema/anyof.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ func (v AnyOf) Serialize(value interface{}) (interface{}, error) {
8080
return value, nil
8181
}
8282

83+
// LessFunc implements the FieldComparator interface, and returns the first
84+
// non-nil LessFunc or nil.
85+
func (v AnyOf) LessFunc() LessFunc {
86+
for _, comparable := range v {
87+
if fc, ok := comparable.(FieldComparator); ok {
88+
if less := fc.LessFunc(); less != nil {
89+
return less
90+
}
91+
}
92+
}
93+
return nil
94+
}
95+
8396
// GetField implements the FieldGetter interface. Note that it will return the
8497
// first matching field only.
8598
func (v AnyOf) GetField(name string) *Field {

schema/anyof_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,56 @@ func TestAnyOfSerialize(t *testing.T) {
220220
cases[i].Run(t)
221221
}
222222
}
223+
224+
func TestAnyOfLesser(t *testing.T) {
225+
cases := map[string]struct {
226+
validator schema.AnyOf
227+
value, other interface{}
228+
expectNilFunc bool
229+
expectResult bool
230+
}{
231+
`AnyOf{Null,Integer}.Less(1,2)`: {
232+
validator: schema.AnyOf{&schema.Null{}, &schema.Integer{}},
233+
value: 1,
234+
other: 2,
235+
expectResult: true,
236+
},
237+
`AnyOf{Null,Integer}.Less(2,1)`: {
238+
validator: schema.AnyOf{&schema.Null{}, &schema.Integer{}},
239+
value: 2,
240+
other: 1,
241+
expectResult: false,
242+
},
243+
`AnyOf{Null,Dict}.Less(2,1)`: {
244+
validator: schema.AnyOf{&schema.Null{}, &schema.Dict{}},
245+
value: 2,
246+
other: 1,
247+
expectNilFunc: true,
248+
},
249+
}
250+
251+
for name, tt := range cases {
252+
tt := tt // capture range variable.
253+
t.Run(name, func(t *testing.T) {
254+
t.Parallel()
255+
v := &tt.validator
256+
v.Compile(nil)
257+
lessFunc := v.LessFunc()
258+
259+
if lessFunc == nil && !tt.expectNilFunc {
260+
t.Error("for validator.LessFunc(), expected non-nil result")
261+
return
262+
} else if lessFunc != nil && tt.expectNilFunc {
263+
t.Error("for validator.LessFunc(), expected nil result")
264+
return
265+
} else if lessFunc == nil {
266+
return
267+
}
268+
269+
got := lessFunc(tt.value, tt.other)
270+
if got != tt.expectResult {
271+
t.Errorf("for lessFunc(%v,%v)\ngot: %v\nwant: %v", tt.value, tt.other, got, tt.expectResult)
272+
}
273+
})
274+
}
275+
}

schema/field.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,18 @@ type FieldGetter interface {
127127
GetField(name string) *Field
128128
}
129129

130-
// FieldComparator defines an interface for comparing field values, depending
131-
// on the each value type semantics. Field types need to implement this interface
132-
// if you want to use $gt, $gte, $lt, $lte for rest-layer-mem store.
130+
// LessFunc is a function that returns true only when value is less than other,
131+
// and false in all other circumstances, including error conditions.
132+
type LessFunc func(value, other interface{}) bool
133+
134+
// FieldComparator must be implemented by a FieldValidator that is to allow
135+
// comparison queries ($gt, $gte, $lt and $lte). The returned LessFunc will be
136+
// used by the query package's Predicate.Match functions, which is used e.g. by
137+
// the internal mem storage backend.
133138
type FieldComparator interface {
134-
// Less returns true if "value" is less than "other", otherwise returns true.
135-
Less(value, other interface{}) bool
139+
// LessFunc returns a valid LessFunc or nil. nil is returned when comparison
140+
// is not allowed.
141+
LessFunc() LessFunc
136142
}
137143

138144
// FieldQueryValidator defines an interface for lightweight validation on field

schema/float.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,6 @@ type Float struct {
1717
Boundaries *Boundaries
1818
}
1919

20-
func (v Float) parse(value interface{}) (interface{}, error) {
21-
f, ok := value.(float64)
22-
if !ok {
23-
return nil, errors.New("not a float")
24-
}
25-
return f, nil
26-
}
27-
2820
// ValidateQuery implements schema.FieldQueryValidator interface
2921
func (v Float) ValidateQuery(value interface{}) (interface{}, error) {
3022
return v.parse(value)
@@ -68,11 +60,23 @@ func (v Float) Validate(value interface{}) (interface{}, error) {
6860
return f, nil
6961
}
7062

71-
// Less implements schema.FieldComparator interface
72-
func (v Float) Less(value, other interface{}) bool {
73-
t, err := v.get(value)
74-
o, err1 := v.get(other)
75-
if err != nil || err1 != nil {
63+
func (v Float) parse(value interface{}) (interface{}, error) {
64+
f, ok := value.(float64)
65+
if !ok {
66+
return nil, errors.New("not a float")
67+
}
68+
return f, nil
69+
}
70+
71+
// LessFunc implements the FieldComparator interface.
72+
func (v Float) LessFunc() LessFunc {
73+
return v.less
74+
}
75+
76+
func (v Float) less(value, other interface{}) bool {
77+
t, err1 := v.get(value)
78+
o, err2 := v.get(other)
79+
if err1 != nil || err2 != nil {
7680
return false
7781
}
7882
return t < o

schema/float_test.go

Lines changed: 28 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package schema
1+
package schema_test
22

33
import (
44
"errors"
@@ -7,21 +7,23 @@ import (
77
"testing"
88

99
"github.com/stretchr/testify/assert"
10+
11+
"github.com/rs/rest-layer/schema"
1012
)
1113

1214
func TestFloatQueryValidator(t *testing.T) {
1315
cases := []struct {
1416
name string
15-
field Float
17+
field schema.Float
1618
input, expect interface{}
1719
err error
1820
}{
19-
{`Float.ValidateQuery(float64)`, Float{}, 1.2, 1.2, nil},
20-
{`Float.ValidateQuery(int)`, Float{}, 1, nil, errors.New("not a float")},
21-
{`Float.ValidateQuery(string)`, Float{}, "1.2", nil, errors.New("not a float")},
22-
{"Float.ValidateQuery(float64)-out of range above", Float{Boundaries: &Boundaries{Min: 0, Max: 2}}, 3.1, 3.1, nil},
23-
{"Float.ValidateQuery(float64)-in range", Float{Boundaries: &Boundaries{Min: 0, Max: 2}}, 1.1, 1.1, nil},
24-
{"Float.ValidateQuery(float64)-out of range below", Float{Boundaries: &Boundaries{Min: 2, Max: 10}}, 1.1, 1.1, nil},
21+
{`Float.ValidateQuery(float64)`, schema.Float{}, 1.2, 1.2, nil},
22+
{`Float.ValidateQuery(int)`, schema.Float{}, 1, nil, errors.New("not a float")},
23+
{`Float.ValidateQuery(string)`, schema.Float{}, "1.2", nil, errors.New("not a float")},
24+
{"Float.ValidateQuery(float64)-out of range above", schema.Float{Boundaries: &schema.Boundaries{Min: 0, Max: 2}}, 3.1, 3.1, nil},
25+
{"Float.ValidateQuery(float64)-in range", schema.Float{Boundaries: &schema.Boundaries{Min: 0, Max: 2}}, 1.1, 1.1, nil},
26+
{"Float.ValidateQuery(float64)-out of range below", schema.Float{Boundaries: &schema.Boundaries{Min: 2, Max: 10}}, 1.1, 1.1, nil},
2527
}
2628
for i := range cases {
2729
tt := cases[i]
@@ -39,104 +41,53 @@ func TestFloatQueryValidator(t *testing.T) {
3941
}
4042

4143
func TestFloatValidator(t *testing.T) {
42-
s, err := Float{}.Validate(1.2)
44+
s, err := schema.Float{}.Validate(1.2)
4345
assert.NoError(t, err)
4446
assert.Equal(t, 1.2, s)
45-
s, err = Float{}.Validate(1)
47+
s, err = schema.Float{}.Validate(1)
4648
assert.EqualError(t, err, "not a float")
4749
assert.Nil(t, s)
48-
s, err = Float{}.Validate("1.2")
50+
s, err = schema.Float{}.Validate("1.2")
4951
assert.EqualError(t, err, "not a float")
5052
assert.Nil(t, s)
51-
s, err = Float{Boundaries: &Boundaries{Min: 0, Max: 2}}.Validate(3.1)
53+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: 0, Max: 2}}.Validate(3.1)
5254
assert.EqualError(t, err, "is greater than 2.00")
5355
assert.Nil(t, s)
54-
s, err = Float{Boundaries: &Boundaries{Min: 0, Max: 2}}.Validate(1.1)
56+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: 0, Max: 2}}.Validate(1.1)
5557
assert.NoError(t, err)
5658
assert.Equal(t, 1.1, s)
57-
s, err = Float{Boundaries: &Boundaries{Min: 2, Max: 10}}.Validate(1.1)
59+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: 2, Max: 10}}.Validate(1.1)
5860
assert.EqualError(t, err, "is lower than 2.00")
5961
assert.Nil(t, s)
60-
s, err = Float{Boundaries: &Boundaries{Min: 2, Max: 10}}.Validate(3.1)
62+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: 2, Max: 10}}.Validate(3.1)
6163
assert.NoError(t, err)
6264
assert.Equal(t, 3.1, s)
63-
s, err = Float{Boundaries: &Boundaries{Min: math.Inf(-1), Max: 10}}.Validate(3.1)
65+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: math.Inf(-1), Max: 10}}.Validate(3.1)
6466
assert.NoError(t, err)
6567
assert.Equal(t, 3.1, s)
66-
s, err = Float{Boundaries: &Boundaries{Min: math.NaN(), Max: 10}}.Validate(3.1)
68+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: math.NaN(), Max: 10}}.Validate(3.1)
6769
assert.NoError(t, err)
6870
assert.Equal(t, 3.1, s)
69-
s, err = Float{Boundaries: &Boundaries{Min: 2, Max: math.Inf(1)}}.Validate(3.1)
71+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: 2, Max: math.Inf(1)}}.Validate(3.1)
7072
assert.NoError(t, err)
7173
assert.Equal(t, 3.1, s)
72-
s, err = Float{Boundaries: &Boundaries{Min: 2, Max: math.NaN()}}.Validate(3.1)
74+
s, err = schema.Float{Boundaries: &schema.Boundaries{Min: 2, Max: math.NaN()}}.Validate(3.1)
7375
assert.NoError(t, err)
7476
assert.Equal(t, 3.1, s)
75-
s, err = Float{Boundaries: &Boundaries{}}.Validate(1.1)
77+
s, err = schema.Float{Boundaries: &schema.Boundaries{}}.Validate(1.1)
7678
assert.EqualError(t, err, "is greater than 0.00")
7779
assert.Nil(t, s)
78-
s, err = Float{Boundaries: &Boundaries{}}.Validate(-1.1)
80+
s, err = schema.Float{Boundaries: &schema.Boundaries{}}.Validate(-1.1)
7981
assert.EqualError(t, err, "is lower than 0.00")
8082
assert.Nil(t, s)
81-
s, err = Float{Allowed: []float64{.1, .2, .3}}.Validate(.4)
83+
s, err = schema.Float{Allowed: []float64{.1, .2, .3}}.Validate(.4)
8284
assert.EqualError(t, err, "not one of the allowed values")
8385
assert.Nil(t, s)
84-
s, err = Float{Allowed: []float64{.1, .2, .3}}.Validate(.2)
86+
s, err = schema.Float{Allowed: []float64{.1, .2, .3}}.Validate(.2)
8587
assert.NoError(t, err)
8688
assert.Equal(t, .2, s)
8789
}
8890

89-
func TestFloatParse(t *testing.T) {
90-
cases := []struct {
91-
name string
92-
input, expect interface{}
93-
err error
94-
}{
95-
{`Float.parse(float64)`, 1.2, 1.2, nil},
96-
{`Float.parse(int)`, 1, nil, errors.New("not a float")},
97-
{`Float.parse(string)`, "1.2", nil, errors.New("not a float")},
98-
}
99-
for i := range cases {
100-
tt := cases[i]
101-
t.Run(tt.name, func(t *testing.T) {
102-
t.Parallel()
103-
got, err := Float{}.parse(tt.input)
104-
if !reflect.DeepEqual(err, tt.err) {
105-
t.Errorf("unexpected error for `%v`\ngot: %v\nwant: %v", tt.input, err, tt.err)
106-
}
107-
if !reflect.DeepEqual(got, tt.expect) {
108-
t.Errorf("invalid output for `%v`:\ngot: %#v\nwant: %#v", tt.input, got, tt.expect)
109-
}
110-
})
111-
}
112-
}
113-
114-
func TestFloatGet(t *testing.T) {
115-
cases := []struct {
116-
name string
117-
field Float
118-
input, expect interface{}
119-
err error
120-
}{
121-
{`Float.get(float64)`, Float{}, 1.2, 1.2, nil},
122-
{`Float.get(int)`, Float{}, 1, 0.0, errors.New("not a float")},
123-
{`Float.get(string)`, Float{}, "1.2", 0.0, errors.New("not a float")},
124-
}
125-
for i := range cases {
126-
tt := cases[i]
127-
t.Run(tt.name, func(t *testing.T) {
128-
t.Parallel()
129-
got, err := (tt.field).get(tt.input)
130-
if !reflect.DeepEqual(err, tt.err) {
131-
t.Errorf("unexpected error for `%v`\ngot: %v\nwant: %v", tt.input, err, tt.err)
132-
}
133-
if !reflect.DeepEqual(got, tt.expect) {
134-
t.Errorf("invalid output for `%v`:\ngot: %#v\nwant: %#v", tt.input, got, tt.expect)
135-
}
136-
})
137-
}
138-
}
139-
14091
func TestFloatLesser(t *testing.T) {
14192
cases := []struct {
14293
name string
@@ -148,11 +99,13 @@ func TestFloatLesser(t *testing.T) {
14899
{`Float.Less(2.0,1.0)`, 2.0, 1.0, false},
149100
{`Float.Less(1.0,"2.0")`, 1.0, "2.0", false},
150101
}
102+
lessFunc := schema.Float{}.LessFunc()
103+
151104
for i := range cases {
152105
tt := cases[i]
153106
t.Run(tt.name, func(t *testing.T) {
154107
t.Parallel()
155-
got := Float{}.Less(tt.value, tt.other)
108+
got := lessFunc(tt.value, tt.other)
156109
if got != tt.expected {
157110
t.Errorf("output for `%v`\ngot: %v\nwant: %v", tt.name, got, tt.expected)
158111
}

schema/integer.go

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,6 @@ type Integer struct {
1212
Boundaries *Boundaries
1313
}
1414

15-
func (v Integer) parse(value interface{}) (interface{}, error) {
16-
if f, ok := value.(float64); ok {
17-
// JSON unmarshaling treat all numbers as float64, try to convert it to
18-
// int if not fraction.
19-
i, frac := math.Modf(f)
20-
if frac == 0.0 {
21-
v := int(i)
22-
value = v
23-
}
24-
}
25-
i, ok := value.(int)
26-
if !ok {
27-
return nil, errors.New("not an integer")
28-
}
29-
return i, nil
30-
}
31-
3215
// ValidateQuery implements schema.FieldQueryValidator interface
3316
func (v Integer) ValidateQuery(value interface{}) (interface{}, error) {
3417
return v.parse(value)
@@ -76,11 +59,32 @@ func (v Integer) Validate(value interface{}) (interface{}, error) {
7659
return i, nil
7760
}
7861

79-
// Less implements schema.FieldComparator interface
80-
func (v Integer) Less(value, other interface{}) bool {
81-
t, err := v.get(value)
82-
o, err1 := v.get(other)
83-
if err != nil || err1 != nil {
62+
func (v Integer) parse(value interface{}) (interface{}, error) {
63+
if f, ok := value.(float64); ok {
64+
// JSON unmarshaling treat all numbers as float64, try to convert it to
65+
// int if not fraction.
66+
i, frac := math.Modf(f)
67+
if frac == 0.0 {
68+
v := int(i)
69+
value = v
70+
}
71+
}
72+
i, ok := value.(int)
73+
if !ok {
74+
return nil, errors.New("not an integer")
75+
}
76+
return i, nil
77+
}
78+
79+
// LessFunc implements the FieldComparator interface.
80+
func (v Integer) LessFunc() LessFunc {
81+
return v.less
82+
}
83+
84+
func (v Integer) less(value, other interface{}) bool {
85+
t, err1 := v.get(value)
86+
o, err2 := v.get(other)
87+
if err1 != nil || err2 != nil {
8488
return false
8589
}
8690
return t < o

0 commit comments

Comments
 (0)