2
2
import os
3
3
from functools import reduce
4
4
from collections import deque
5
+ from typing import Callable , Iterator , List , Optional , Tuple , Type , TypeVar , Union , Dict , Any , Sequence
5
6
6
7
###{standalone
7
8
import sys , re
8
9
import logging
10
+
9
11
logger : logging .Logger = logging .getLogger ("lark" )
10
12
logger .addHandler (logging .StreamHandler ())
11
13
# Set to highest level, since we have some warnings amongst the code
15
17
16
18
NO_VALUE = object ()
17
19
20
+ T = TypeVar ("T" )
21
+
18
22
19
- def classify (seq , key = None , value = None ):
20
- d = {}
23
+ def classify (seq : Sequence , key : Optional [ Callable ] = None , value : Optional [ Callable ] = None ) -> Dict :
24
+ d : Dict [ Any , Any ] = {}
21
25
for item in seq :
22
26
k = key (item ) if (key is not None ) else item
23
27
v = value (item ) if (value is not None ) else item
@@ -28,7 +32,7 @@ def classify(seq, key=None, value=None):
28
32
return d
29
33
30
34
31
- def _deserialize (data , namespace , memo ) :
35
+ def _deserialize (data : Any , namespace : Dict [ str , Any ], memo : Dict ) -> Any :
32
36
if isinstance (data , dict ):
33
37
if '__type__' in data : # Object
34
38
class_ = namespace [data ['__type__' ]]
@@ -41,6 +45,8 @@ def _deserialize(data, namespace, memo):
41
45
return data
42
46
43
47
48
+ _T = TypeVar ("_T" , bound = "Serialize" )
49
+
44
50
class Serialize :
45
51
"""Safe-ish serialization interface that doesn't rely on Pickle
46
52
@@ -50,23 +56,23 @@ class Serialize:
50
56
Should include all field types that aren't builtin types.
51
57
"""
52
58
53
- def memo_serialize (self , types_to_memoize ) :
59
+ def memo_serialize (self , types_to_memoize : List ) -> Any :
54
60
memo = SerializeMemoizer (types_to_memoize )
55
61
return self .serialize (memo ), memo .serialize ()
56
62
57
- def serialize (self , memo = None ):
63
+ def serialize (self , memo = None ) -> Dict [ str , Any ] :
58
64
if memo and memo .in_types (self ):
59
65
return {'@' : memo .memoized .get (self )}
60
66
61
67
fields = getattr (self , '__serialize_fields__' )
62
68
res = {f : _serialize (getattr (self , f ), memo ) for f in fields }
63
69
res ['__type__' ] = type (self ).__name__
64
70
if hasattr (self , '_serialize' ):
65
- self ._serialize (res , memo )
71
+ self ._serialize (res , memo ) # type: ignore[attr-defined]
66
72
return res
67
73
68
74
@classmethod
69
- def deserialize (cls , data , memo ) :
75
+ def deserialize (cls : Type [ _T ] , data : Dict [ str , Any ], memo : Dict [ int , Any ]) -> _T :
70
76
namespace = getattr (cls , '__serialize_namespace__' , [])
71
77
namespace = {c .__name__ :c for c in namespace }
72
78
@@ -83,7 +89,7 @@ def deserialize(cls, data, memo):
83
89
raise KeyError ("Cannot find key for class" , cls , e )
84
90
85
91
if hasattr (inst , '_deserialize' ):
86
- inst ._deserialize ()
92
+ inst ._deserialize () # type: ignore[attr-defined]
87
93
88
94
return inst
89
95
@@ -93,18 +99,18 @@ class SerializeMemoizer(Serialize):
93
99
94
100
__serialize_fields__ = 'memoized' ,
95
101
96
- def __init__ (self , types_to_memoize ) :
102
+ def __init__ (self , types_to_memoize : List ) -> None :
97
103
self .types_to_memoize = tuple (types_to_memoize )
98
104
self .memoized = Enumerator ()
99
105
100
- def in_types (self , value ) :
106
+ def in_types (self , value : Serialize ) -> bool :
101
107
return isinstance (value , self .types_to_memoize )
102
108
103
- def serialize (self ):
109
+ def serialize (self ) -> Dict [ int , Any ]: # type: ignore[override]
104
110
return _serialize (self .memoized .reversed (), None )
105
111
106
112
@classmethod
107
- def deserialize (cls , data , namespace , memo ):
113
+ def deserialize (cls , data : Dict [ int , Any ], namespace : Dict [ str , Any ], memo : Dict [ Any , Any ]) -> Dict [ int , Any ]: # type: ignore[override]
108
114
return _deserialize (data , namespace , memo )
109
115
110
116
@@ -123,7 +129,7 @@ def deserialize(cls, data, namespace, memo):
123
129
124
130
categ_pattern = re .compile (r'\\p{[A-Za-z_]+}' )
125
131
126
- def get_regexp_width (expr ) :
132
+ def get_regexp_width (expr : str ) -> Union [ Tuple [ int , int ], List [ int ]] :
127
133
if _has_regex :
128
134
# Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
129
135
# a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
@@ -134,7 +140,8 @@ def get_regexp_width(expr):
134
140
raise ImportError ('`regex` module must be installed in order to use Unicode categories.' , expr )
135
141
regexp_final = expr
136
142
try :
137
- return [int (x ) for x in sre_parse .parse (regexp_final ).getwidth ()]
143
+ # Fixed in next version (past 0.960) of typeshed
144
+ return [int (x ) for x in sre_parse .parse (regexp_final ).getwidth ()] # type: ignore[attr-defined]
138
145
except sre_constants .error :
139
146
if not _has_regex :
140
147
raise ValueError (expr )
@@ -154,47 +161,50 @@ def get_regexp_width(expr):
154
161
_ID_START = 'Lu' , 'Ll' , 'Lt' , 'Lm' , 'Lo' , 'Mn' , 'Mc' , 'Pc'
155
162
_ID_CONTINUE = _ID_START + ('Nd' , 'Nl' ,)
156
163
157
- def _test_unicode_category (s , categories ) :
164
+ def _test_unicode_category (s : str , categories : Sequence [ str ]) -> bool :
158
165
if len (s ) != 1 :
159
166
return all (_test_unicode_category (char , categories ) for char in s )
160
167
return s == '_' or unicodedata .category (s ) in categories
161
168
162
def is_id_continue(s: str) -> bool:
    """
    Check whether every character of ``s`` is valid in the tail of a Python
    identifier — Unicode ``ID_CONTINUE``, see PEP 3131. Alphanumerics in any
    script (diacritics, Indic vowels, non-latin digits, etc.) all pass.
    """
    return _test_unicode_category(s, _ID_CONTINUE)
168
175
169
def is_id_start(s: str) -> bool:
    """
    Check whether every character of ``s`` is valid at the start of a Python
    identifier — Unicode ``ID_START``, see PEP 3131. Alphabetic characters in
    any script (diacritics, Indic vowels, etc.) all pass.
    """
    return _test_unicode_category(s, _ID_START)
175
182
176
183
177
def dedup_list(l: List[T]) -> List[T]:
    """Remove duplicates from list ``l``, preserving the original order.

    Assumes the list entries are hashable.

    ``dict.fromkeys`` preserves insertion order (guaranteed since Python 3.7,
    and true in CPython 3.6/PyPy as well) and is roughly 2x faster than the
    set-tracking comprehension it replaces, which also relied on the obscure
    fact that ``set.add`` returns ``None`` inside a boolean expression.
    """
    return list(dict.fromkeys(l))
183
193
184
194
185
195
class Enumerator(Serialize):
    """Assigns a stable, sequential integer id to each distinct item seen."""

    def __init__(self) -> None:
        # item -> id, in first-seen order.
        self.enums: Dict[Any, int] = {}

    def get(self, item) -> int:
        # A new item receives the next sequential id; a known item keeps its
        # original one. `len(self.enums)` is evaluated before insertion.
        return self.enums.setdefault(item, len(self.enums))

    def __len__(self):
        return len(self.enums)

    def reversed(self) -> Dict[int, Any]:
        # Invert the mapping (id -> item); ids are unique by construction,
        # so no entries may collapse.
        inverse = {num: item for item, num in self.enums.items()}
        assert len(inverse) == len(self.enums)
        return inverse
@@ -240,11 +250,11 @@ def open(name, mode="r", **kwargs):
240
250
241
251
242
252
243
- def isascii (s ) :
253
+ def isascii (s : str ) -> bool :
244
254
""" str.isascii only exists in python3.7+ """
245
- try :
255
+ if sys . version_info >= ( 3 , 7 ) :
246
256
return s .isascii ()
247
- except AttributeError :
257
+ else :
248
258
try :
249
259
s .encode ('ascii' )
250
260
return True
@@ -257,7 +267,7 @@ def __repr__(self):
257
267
return '{%s}' % ', ' .join (map (repr , self ))
258
268
259
269
260
- def classify_bool (seq , pred ) :
270
+ def classify_bool (seq : Sequence , pred : Callable ) -> Any :
261
271
true_elems = []
262
272
false_elems = []
263
273
@@ -270,7 +280,7 @@ def classify_bool(seq, pred):
270
280
return true_elems , false_elems
271
281
272
282
273
- def bfs (initial , expand ) :
283
+ def bfs (initial : Sequence , expand : Callable ) -> Iterator :
274
284
open_q = deque (list (initial ))
275
285
visited = set (open_q )
276
286
while open_q :
@@ -290,7 +300,7 @@ def bfs_all_unique(initial, expand):
290
300
open_q += expand (node )
291
301
292
302
293
- def _serialize (value , memo ) :
303
+ def _serialize (value : Any , memo : Optional [ SerializeMemoizer ]) -> Any :
294
304
if isinstance (value , Serialize ):
295
305
return value .serialize (memo )
296
306
elif isinstance (value , list ):
@@ -305,7 +315,7 @@ def _serialize(value, memo):
305
315
306
316
307
317
308
- def small_factors (n , max_factor ) :
318
+ def small_factors (n : int , max_factor : int ) -> List [ Tuple [ int , int ]] :
309
319
"""
310
320
Splits n up into smaller factors and summands <= max_factor.
311
321
Returns a list of [(a, b), ...]
0 commit comments