1
1
// Validator for things inside of a typing.Literal[]
2
2
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
3
3
use core:: fmt:: Debug ;
4
- use std:: cmp:: Ordering ;
5
4
6
5
use pyo3:: prelude:: * ;
7
6
use pyo3:: types:: { PyDict , PyInt , PyList } ;
@@ -35,7 +34,7 @@ pub struct LiteralLookup<T: Debug> {
35
34
// Catch all for hashable types like Enum and bytes (the latter only because it is seldom used)
36
35
expected_py_dict : Option < Py < PyDict > > ,
37
36
// Catch all for unhashable types like list
38
- expected_py_list : Option < Py < PyList > > ,
37
+ expected_py_values : Option < Vec < ( Py < PyAny > , usize ) > > ,
39
38
40
39
pub values : Vec < T > ,
41
40
}
@@ -46,7 +45,7 @@ impl<T: Debug> LiteralLookup<T> {
46
45
let mut expected_int = AHashMap :: new ( ) ;
47
46
let mut expected_str: AHashMap < String , usize > = AHashMap :: new ( ) ;
48
47
let expected_py_dict = PyDict :: new_bound ( py) ;
49
- let expected_py_list = PyList :: empty_bound ( py ) ;
48
+ let mut expected_py_values = Vec :: new ( ) ;
50
49
let mut values = Vec :: new ( ) ;
51
50
for ( k, v) in expected {
52
51
let id = values. len ( ) ;
@@ -71,7 +70,7 @@ impl<T: Debug> LiteralLookup<T> {
71
70
. map_err ( |_| py_schema_error_type ! ( "error extracting str {:?}" , k) ) ?;
72
71
expected_str. insert ( str. to_string ( ) , id) ;
73
72
} else if expected_py_dict. set_item ( & k, id) . is_err ( ) {
74
- expected_py_list . append ( ( & k , id) ) ? ;
73
+ expected_py_values . push ( ( k . as_unbound ( ) . clone_ref ( py ) , id) ) ;
75
74
}
76
75
}
77
76
@@ -92,9 +91,9 @@ impl<T: Debug> LiteralLookup<T> {
92
91
true => None ,
93
92
false => Some ( expected_py_dict. into ( ) ) ,
94
93
} ,
95
- expected_py_list : match expected_py_list . is_empty ( ) {
94
+ expected_py_values : match expected_py_values . is_empty ( ) {
96
95
true => None ,
97
- false => Some ( expected_py_list . into ( ) ) ,
96
+ false => Some ( expected_py_values ) ,
98
97
} ,
99
98
values,
100
99
} )
@@ -143,23 +142,23 @@ impl<T: Debug> LiteralLookup<T> {
143
142
}
144
143
}
145
144
}
145
+ // cache py_input if needed, since we might need it for multiple lookups
146
+ let mut py_input = None ;
146
147
if let Some ( expected_py_dict) = & self . expected_py_dict {
148
+ let py_input = py_input. get_or_insert_with ( || input. to_object ( py) ) ;
147
149
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
148
150
// inputs will produce a TypeError, which in this case we just want to treat equivalently
149
151
// to a failed lookup
150
- if let Ok ( Some ( v) ) = expected_py_dict. bind ( py) . get_item ( input ) {
152
+ if let Ok ( Some ( v) ) = expected_py_dict. bind ( py) . get_item ( & * py_input ) {
151
153
let id: usize = v. extract ( ) . unwrap ( ) ;
152
154
return Ok ( Some ( ( input, & self . values [ id] ) ) ) ;
153
155
}
154
156
} ;
155
- if let Some ( expected_py_list) = & self . expected_py_list {
156
- for item in expected_py_list. bind ( py) {
157
- let ( k, id) : ( Bound < PyAny > , usize ) = item. extract ( ) ?;
158
- if k. compare ( input. to_object ( py) . bind ( py) )
159
- . unwrap_or ( Ordering :: Less )
160
- . is_eq ( )
161
- {
162
- return Ok ( Some ( ( input, & self . values [ id] ) ) ) ;
157
+ if let Some ( expected_py_values) = & self . expected_py_values {
158
+ let py_input = py_input. get_or_insert_with ( || input. to_object ( py) ) ;
159
+ for ( k, id) in expected_py_values {
160
+ if k. bind ( py) . eq ( & * py_input) . unwrap_or ( false ) {
161
+ return Ok ( Some ( ( input, & self . values [ * id] ) ) ) ;
163
162
}
164
163
}
165
164
} ;
0 commit comments