| 
4 | 4 | 
 
  | 
5 | 5 | from .._common import group_consecutive_idx  | 
6 | 6 | from ..dataclasses import ParsedIngredient, ParserDebugInfo  | 
 | 7 | +from ._constants import FORBIDDEN_TRANSITIONS, LABELS  | 
7 | 8 | from ._loaders import load_parser_model  | 
8 | 9 | from ._utils import pluralise_units  | 
9 | 10 | from .postprocess import PostProcessor  | 
@@ -275,49 +276,6 @@ def guess_ingredient_name(  | 
275 | 276 |     return labels, scores  | 
276 | 277 | 
 
  | 
277 | 278 | 
 
  | 
278 |  | -# Dict of illegal transitions.  | 
279 |  | -# The key is the previous label, the values are the set of labels that cannot be  | 
280 |  | -# predicted for the next label.  | 
281 |  | -# This are generated from the training data: these transition never occur in the  | 
282 |  | -# training data.  | 
283 |  | -ILLEGAL_TRANSITIONS = {  | 
284 |  | -    "B_NAME_TOK": {"B_NAME_TOK", "NAME_MOD"},  | 
285 |  | -    "I_NAME_TOK": {"NAME_MOD"},  | 
286 |  | -    "NAME_MOD": {"COMMENT", "I_NAME_TOK", "PURPOSE", "QTY", "UNIT"},  | 
287 |  | -    "NAME_SEP": {"I_NAME_TOK", "PURPOSE"},  | 
288 |  | -    "NAME_VAR": {"COMMENT", "I_NAME_TOK", "NAME_MOD", "PURPOSE", "QTY", "UNIT"},  | 
289 |  | -    "PREP": {"NAME_SEP"},  | 
290 |  | -    "PURPOSE": {  | 
291 |  | -        "B_NAME_TOK",  | 
292 |  | -        "I_NAME_TOK",  | 
293 |  | -        "NAME_MOD",  | 
294 |  | -        "NAME_SEP",  | 
295 |  | -        "NAME_VAR",  | 
296 |  | -        "PREP",  | 
297 |  | -        "QTY",  | 
298 |  | -        "SIZE",  | 
299 |  | -        "UNIT",  | 
300 |  | -    },  | 
301 |  | -    "QTY": {"NAME_SEP", "PURPOSE"},  | 
302 |  | -    "SIZE": {"NAME_SEP", "PURPOSE"},  | 
303 |  | -}  | 
304 |  | - | 
305 |  | -LABELS = [  | 
306 |  | -    "B_NAME_TOK",  | 
307 |  | -    "COMMENT",  | 
308 |  | -    "I_NAME_TOK",  | 
309 |  | -    "NAME_MOD",  | 
310 |  | -    "NAME_SEP",  | 
311 |  | -    "NAME_VAR",  | 
312 |  | -    "PREP",  | 
313 |  | -    "PUNC",  | 
314 |  | -    "PURPOSE",  | 
315 |  | -    "QTY",  | 
316 |  | -    "SIZE",  | 
317 |  | -    "UNIT",  | 
318 |  | -]  | 
319 |  | - | 
320 |  | - | 
321 | 279 | def apply_label_constraints(  | 
322 | 280 |     TAGGER, labels: list[str], scores: list[float]  | 
323 | 281 | ) -> tuple[list[str], list[float]]:  | 
@@ -376,7 +334,7 @@ def apply_label_constraints(  | 
376 | 334 | 
 
  | 
377 | 335 |         else:  | 
378 | 336 |             prev_label = sequence[-1]  | 
379 |  | -            forbidden = ILLEGAL_TRANSITIONS.get(prev_label, set())  | 
 | 337 | +            forbidden = FORBIDDEN_TRANSITIONS.get(prev_label, set())  | 
380 | 338 |             if label in forbidden:  | 
381 | 339 |                 new_score, new_label = select_best_alternative_label(  | 
382 | 340 |                     TAGGER, i, forbidden  | 
 | 
0 commit comments