Skip to content

Commit f009312

Browse files
authored
Merge pull request #1191 from lark-parser/adjust_pr1152
Adjustments for PR #1152
2 parents f775df3 + dce017c commit f009312

File tree

3 files changed

+59
-45
lines changed

3 files changed

+59
-45
lines changed

lark/lark.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys, os, pickle, hashlib
44
import tempfile
55
import types
6+
import re
67
from typing import (
78
TypeVar, Type, List, Dict, Iterator, Callable, Union, Optional, Sequence,
89
Tuple, Iterable, IO, Any, TYPE_CHECKING, Collection
@@ -15,6 +16,7 @@
1516
from typing import Literal
1617
else:
1718
from typing_extensions import Literal
19+
from .parser_frontends import ParsingFrontend
1820

1921
from .exceptions import ConfigurationError, assert_config, UnexpectedInput
2022
from .utils import Serialize, SerializeMemoizer, FS, isascii, logger
@@ -27,7 +29,7 @@
2729
from .parser_frontends import _validate_frontend_args, _get_lexer_callbacks, _deserialize_parsing_frontend, _construct_parsing_frontend
2830
from .grammar import Rule
2931

30-
import re
32+
3133
try:
3234
import regex
3335
_has_regex = True
@@ -176,7 +178,7 @@ class LarkOptions(Serialize):
176178
'_plugins': {},
177179
}
178180

179-
def __init__(self, options_dict):
181+
def __init__(self, options_dict: Dict[str, Any]) -> None:
180182
o = dict(options_dict)
181183

182184
options = {}
@@ -205,21 +207,21 @@ def __init__(self, options_dict):
205207
if o:
206208
raise ConfigurationError("Unknown options: %s" % o.keys())
207209

208-
def __getattr__(self, name):
210+
def __getattr__(self, name: str) -> Any:
209211
try:
210212
return self.__dict__['options'][name]
211213
except KeyError as e:
212214
raise AttributeError(e)
213215

214-
def __setattr__(self, name, value):
216+
def __setattr__(self, name: str, value: str) -> None:
215217
assert_config(name, self.options.keys(), "%r isn't a valid option. Expected one of: %s")
216218
self.options[name] = value
217219

218-
def serialize(self, memo):
220+
def serialize(self, memo = None) -> Dict[str, Any]:
219221
return self.options
220222

221223
@classmethod
222-
def deserialize(cls, data, memo):
224+
def deserialize(cls, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]]) -> "LarkOptions":
223225
return cls(data)
224226

225227

@@ -252,7 +254,7 @@ class Lark(Serialize):
252254
grammar: 'Grammar'
253255
options: LarkOptions
254256
lexer: Lexer
255-
terminals: List[TerminalDef]
257+
terminals: Collection[TerminalDef]
256258

257259
def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:
258260
self.options = LarkOptions(options)
@@ -446,15 +448,15 @@ def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None:
446448

447449
__serialize_fields__ = 'parser', 'rules', 'options'
448450

449-
def _build_lexer(self, dont_ignore=False):
451+
def _build_lexer(self, dont_ignore: bool=False) -> BasicLexer:
450452
lexer_conf = self.lexer_conf
451453
if dont_ignore:
452454
from copy import copy
453455
lexer_conf = copy(lexer_conf)
454456
lexer_conf.ignore = ()
455457
return BasicLexer(lexer_conf)
456458

457-
def _prepare_callbacks(self):
459+
def _prepare_callbacks(self) -> None:
458460
self._callbacks = {}
459461
# we don't need these callbacks if we aren't building a tree
460462
if self.options.ambiguity != 'forest':
@@ -468,7 +470,7 @@ def _prepare_callbacks(self):
468470
self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer)
469471
self._callbacks.update(_get_lexer_callbacks(self.options.transformer, self.terminals))
470472

471-
def _build_parser(self):
473+
def _build_parser(self) -> "ParsingFrontend":
472474
self._prepare_callbacks()
473475
_validate_frontend_args(self.options.parser, self.options.lexer)
474476
parser_conf = ParserConf(self.rules, self._callbacks, self.options.start)
@@ -480,7 +482,7 @@ def _build_parser(self):
480482
options=self.options
481483
)
482484

483-
def save(self, f, exclude_options: Collection[str] = ()):
485+
def save(self, f, exclude_options: Collection[str] = ()) -> None:
484486
"""Saves the instance into the given file object
485487
486488
Useful for caching and multiprocessing.
@@ -491,15 +493,15 @@ def save(self, f, exclude_options: Collection[str] = ()):
491493
pickle.dump({'data': data, 'memo': m}, f, protocol=pickle.HIGHEST_PROTOCOL)
492494

493495
@classmethod
494-
def load(cls, f):
496+
def load(cls: Type[_T], f) -> _T:
495497
"""Loads an instance from the given file object
496498
497499
Useful for caching and multiprocessing.
498500
"""
499501
inst = cls.__new__(cls)
500502
return inst._load(f)
501503

502-
def _deserialize_lexer_conf(self, data, memo, options):
504+
def _deserialize_lexer_conf(self, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]], options: LarkOptions) -> LexerConf:
503505
lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo)
504506
lexer_conf.callbacks = options.lexer_callbacks or {}
505507
lexer_conf.re_module = regex if options.regex else re
@@ -509,7 +511,7 @@ def _deserialize_lexer_conf(self, data, memo, options):
509511
lexer_conf.postlex = options.postlex
510512
return lexer_conf
511513

512-
def _load(self, f, **kwargs):
514+
def _load(self: _T, f: Any, **kwargs) -> _T:
513515
if isinstance(f, dict):
514516
d = f
515517
else:
@@ -593,6 +595,7 @@ def lex(self, text: str, dont_ignore: bool=False) -> Iterator[Token]:
593595
594596
:raises UnexpectedCharacters: In case the lexer cannot find a suitable match.
595597
"""
598+
lexer: Lexer
596599
if not hasattr(self, 'lexer') or dont_ignore:
597600
lexer = self._build_lexer(dont_ignore)
598601
else:

lark/parsers/lalr_parser.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Author: Erez Shinan (2017)
44
# Email : erezshin@gmail.com
55
from copy import deepcopy, copy
6+
from typing import Dict, Any
67
from ..lexer import Token
78
from ..utils import Serialize
89

@@ -29,7 +30,7 @@ def deserialize(cls, data, memo, callbacks, debug=False):
2930
inst.parser = _Parser(inst._parse_table, callbacks, debug)
3031
return inst
3132

32-
def serialize(self, memo):
33+
def serialize(self, memo: Any = None) -> Dict[str, Any]:
3334
return self._parse_table.serialize(memo)
3435

3536
def parse_interactive(self, lexer, start):

lark/utils.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import os
33
from functools import reduce
44
from collections import deque
5+
from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence
56

67
###{standalone
78
import sys, re
89
import logging
10+
911
logger: logging.Logger = logging.getLogger("lark")
1012
logger.addHandler(logging.StreamHandler())
1113
# Set to highest level, since we have some warnings amongst the code
@@ -15,9 +17,11 @@
1517

1618
NO_VALUE = object()
1719

20+
T = TypeVar("T")
21+
1822

19-
def classify(seq, key=None, value=None):
20-
d = {}
23+
def classify(seq: Sequence, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
24+
d: Dict[Any, Any] = {}
2125
for item in seq:
2226
k = key(item) if (key is not None) else item
2327
v = value(item) if (value is not None) else item
@@ -28,7 +32,7 @@ def classify(seq, key=None, value=None):
2832
return d
2933

3034

31-
def _deserialize(data, namespace, memo):
35+
def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
3236
if isinstance(data, dict):
3337
if '__type__' in data: # Object
3438
class_ = namespace[data['__type__']]
@@ -41,6 +45,8 @@ def _deserialize(data, namespace, memo):
4145
return data
4246

4347

48+
_T = TypeVar("_T", bound="Serialize")
49+
4450
class Serialize:
4551
"""Safe-ish serialization interface that doesn't rely on Pickle
4652
@@ -50,23 +56,23 @@ class Serialize:
5056
Should include all field types that aren't builtin types.
5157
"""
5258

53-
def memo_serialize(self, types_to_memoize):
59+
def memo_serialize(self, types_to_memoize: List) -> Any:
5460
memo = SerializeMemoizer(types_to_memoize)
5561
return self.serialize(memo), memo.serialize()
5662

57-
def serialize(self, memo=None):
63+
def serialize(self, memo = None) -> Dict[str, Any]:
5864
if memo and memo.in_types(self):
5965
return {'@': memo.memoized.get(self)}
6066

6167
fields = getattr(self, '__serialize_fields__')
6268
res = {f: _serialize(getattr(self, f), memo) for f in fields}
6369
res['__type__'] = type(self).__name__
6470
if hasattr(self, '_serialize'):
65-
self._serialize(res, memo)
71+
self._serialize(res, memo) # type: ignore[attr-defined]
6672
return res
6773

6874
@classmethod
69-
def deserialize(cls, data, memo):
75+
def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
7076
namespace = getattr(cls, '__serialize_namespace__', [])
7177
namespace = {c.__name__:c for c in namespace}
7278

@@ -83,7 +89,7 @@ def deserialize(cls, data, memo):
8389
raise KeyError("Cannot find key for class", cls, e)
8490

8591
if hasattr(inst, '_deserialize'):
86-
inst._deserialize()
92+
inst._deserialize() # type: ignore[attr-defined]
8793

8894
return inst
8995

@@ -93,18 +99,18 @@ class SerializeMemoizer(Serialize):
9399

94100
__serialize_fields__ = 'memoized',
95101

96-
def __init__(self, types_to_memoize):
102+
def __init__(self, types_to_memoize: List) -> None:
97103
self.types_to_memoize = tuple(types_to_memoize)
98104
self.memoized = Enumerator()
99105

100-
def in_types(self, value):
106+
def in_types(self, value: Serialize) -> bool:
101107
return isinstance(value, self.types_to_memoize)
102108

103-
def serialize(self):
109+
def serialize(self) -> Dict[int, Any]: # type: ignore[override]
104110
return _serialize(self.memoized.reversed(), None)
105111

106112
@classmethod
107-
def deserialize(cls, data, namespace, memo):
113+
def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]: # type: ignore[override]
108114
return _deserialize(data, namespace, memo)
109115

110116

@@ -123,7 +129,7 @@ def deserialize(cls, data, namespace, memo):
123129

124130
categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')
125131

126-
def get_regexp_width(expr):
132+
def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
127133
if _has_regex:
128134
# Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
129135
# a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
@@ -134,7 +140,8 @@ def get_regexp_width(expr):
134140
raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
135141
regexp_final = expr
136142
try:
137-
return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
143+
# Fixed in next version (past 0.960) of typeshed
144+
return [int(x) for x in sre_parse.parse(regexp_final).getwidth()] # type: ignore[attr-defined]
138145
except sre_constants.error:
139146
if not _has_regex:
140147
raise ValueError(expr)
@@ -154,47 +161,50 @@ def get_regexp_width(expr):
154161
_ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
155162
_ID_CONTINUE = _ID_START + ('Nd', 'Nl',)
156163

157-
def _test_unicode_category(s, categories):
164+
def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
158165
if len(s) != 1:
159166
return all(_test_unicode_category(char, categories) for char in s)
160167
return s == '_' or unicodedata.category(s) in categories
161168

162-
def is_id_continue(s):
169+
def is_id_continue(s: str) -> bool:
163170
"""
164171
Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin
165172
numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
166173
"""
167174
return _test_unicode_category(s, _ID_CONTINUE)
168175

169-
def is_id_start(s):
176+
def is_id_start(s: str) -> bool:
170177
"""
171178
Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, non-latin
172179
numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
173180
"""
174181
return _test_unicode_category(s, _ID_START)
175182

176183

177-
def dedup_list(l):
184+
def dedup_list(l: List[T]) -> List[T]:
178185
"""Given a list (l) will removing duplicates from the list,
179186
preserving the original order of the list. Assumes that
180187
the list entries are hashable."""
181188
dedup = set()
182-
return [x for x in l if not (x in dedup or dedup.add(x))]
189+
# This returns None, but that's expected
190+
return [x for x in l if not (x in dedup or dedup.add(x))] # type: ignore[func-returns-value]
191+
# 2x faster (ordered in PyPy and CPython 3.6+, gaurenteed to be ordered in Python 3.7+)
192+
# return list(dict.fromkeys(l))
183193

184194

185195
class Enumerator(Serialize):
186-
def __init__(self):
187-
self.enums = {}
196+
def __init__(self) -> None:
197+
self.enums: Dict[Any, int] = {}
188198

189-
def get(self, item):
199+
def get(self, item) -> int:
190200
if item not in self.enums:
191201
self.enums[item] = len(self.enums)
192202
return self.enums[item]
193203

194204
def __len__(self):
195205
return len(self.enums)
196206

197-
def reversed(self):
207+
def reversed(self) -> Dict[int, Any]:
198208
r = {v: k for k, v in self.enums.items()}
199209
assert len(r) == len(self.enums)
200210
return r
@@ -240,11 +250,11 @@ def open(name, mode="r", **kwargs):
240250

241251

242252

243-
def isascii(s):
253+
def isascii(s: str) -> bool:
244254
""" str.isascii only exists in python3.7+ """
245-
try:
255+
if sys.version_info >= (3, 7):
246256
return s.isascii()
247-
except AttributeError:
257+
else:
248258
try:
249259
s.encode('ascii')
250260
return True
@@ -257,7 +267,7 @@ def __repr__(self):
257267
return '{%s}' % ', '.join(map(repr, self))
258268

259269

260-
def classify_bool(seq, pred):
270+
def classify_bool(seq: Sequence, pred: Callable) -> Any:
261271
true_elems = []
262272
false_elems = []
263273

@@ -270,7 +280,7 @@ def classify_bool(seq, pred):
270280
return true_elems, false_elems
271281

272282

273-
def bfs(initial, expand):
283+
def bfs(initial: Sequence, expand: Callable) -> Iterator:
274284
open_q = deque(list(initial))
275285
visited = set(open_q)
276286
while open_q:
@@ -290,7 +300,7 @@ def bfs_all_unique(initial, expand):
290300
open_q += expand(node)
291301

292302

293-
def _serialize(value, memo):
303+
def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
294304
if isinstance(value, Serialize):
295305
return value.serialize(memo)
296306
elif isinstance(value, list):
@@ -305,7 +315,7 @@ def _serialize(value, memo):
305315

306316

307317

308-
def small_factors(n, max_factor):
318+
def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
309319
"""
310320
Splits n up into smaller factors and summands <= max_factor.
311321
Returns a list of [(a, b), ...]

0 commit comments

Comments
 (0)