| 
8 | 8 | import string  | 
9 | 9 | import sys  | 
10 | 10 | import traceback  | 
 | 11 | +from http import HTTPStatus  | 
11 | 12 | from importlib.metadata import PackageNotFoundError, distribution  | 
12 | 13 | from pathlib import Path  | 
13 | 14 | 
 
  | 
 | 
27 | 28 | from ingredient_parser.en._loaders import load_parser_model  | 
28 | 29 | from ingredient_parser.en.preprocess import PreProcessor  | 
29 | 30 | 
 
  | 
30 |  | -# globals  | 
 | 31 | +# globals defs  | 
31 | 32 | parent_dir = Path(__file__).parent.parent  | 
32 | 33 | NPM_BUILD_DIRECTORY = "build"  | 
33 | 34 | SQL3_DATABASE_TABLE = "en"  | 
34 | 35 | SQL3_DATABASE = parent_dir / "train/data/training.sqlite3"  | 
35 | 36 | MODEL_REQUIREMENTS = parent_dir / "requirements-dev.txt"  | 
36 |  | -RESERVED_LABELLER_SEARCH_CHARS = r"\*\*|\~\~|\=\="  | 
 | 37 | + | 
 | 38 | + | 
 | 39 | +# global regex  | 
 | 40 | +RESERVED_LABELLER_SEARCH_CHARS = r"\*\*|\~\~|\=\="  # ** or ~~ or ==  | 
 | 41 | +RESERVED_DOTNUM_RANGE_CHARS = (  | 
 | 42 | +    r"^\d*\.?\d*(?<!\.)\.\.(?!\.)\d*\.?\d*$"  # {digit}..{digit}  | 
 | 43 | +)  | 
37 | 44 | 
 
  | 
38 | 45 | # sqlite  | 
39 | 46 | sqlite3.register_adapter(list, json.dumps)  | 
 | 
47 | 54 | load_parser_model.cache_clear()  | 
48 | 55 | 
 
  | 
49 | 56 | 
 
  | 
50 |  | -def error_response(status: int, message: str = ""):  | 
 | 57 | +# helpers  | 
 | 58 | +def is_valid_dotnum_range(s: str) -> bool:  | 
 | 59 | +    """Checks a str against the format "{digit}..{digit}"""  | 
 | 60 | + | 
 | 61 | +    return bool(re.fullmatch(RESERVED_DOTNUM_RANGE_CHARS, s))  | 
 | 62 | + | 
 | 63 | + | 
 | 64 | +def error_response(  | 
 | 65 | +    status: int,  | 
 | 66 | +    traceback: str = "",  | 
 | 67 | +):  | 
51 | 68 |     """Boilerplate for errors"""  | 
52 |  | -    if status == 400:  | 
53 |  | -        return jsonify(  | 
54 |  | -            {"status": 400, "error": "Sorry, bad params", "message": message}  | 
55 |  | -        ), 400  | 
56 |  | -    elif status == 404:  | 
 | 69 | + | 
 | 70 | +    try:  | 
57 | 71 |         return jsonify(  | 
58 |  | -            {"status": 404, "error": "Sorry, resource not found", "message": message}  | 
59 |  | -        ), 404  | 
60 |  | -    elif status == 500:  | 
 | 72 | +            {  | 
 | 73 | +                "status": HTTPStatus(status).value,  | 
 | 74 | +                "error": f"{HTTPStatus(status).name}",  | 
 | 75 | +                "traceback": traceback,  | 
 | 76 | +                "description": HTTPStatus(status).description,  | 
 | 77 | +            }  | 
 | 78 | +        ), HTTPStatus(status).value  | 
 | 79 | +    except Exception:  | 
61 | 80 |         return jsonify(  | 
62 |  | -            {"status": 404, "error": "Sorry, api failed", "message": message}  | 
63 |  | -        ), 500  | 
64 |  | -    else:  | 
65 |  | -        return jsonify(  | 
66 |  | -            {"status": status, "error": "Sorry, something failed", "message": message}  | 
 | 81 | +            {  | 
 | 82 | +                "status": 500,  | 
 | 83 | +                "error": f"{HTTPStatus.INTERNAL_SERVER_ERROR.value}",  | 
 | 84 | +                "traceback": "",  | 
 | 85 | +                "description": HTTPStatus.INTERNAL_SERVER_ERROR.description,  | 
 | 86 | +            }  | 
67 | 87 |         ), 500  | 
68 | 88 | 
 
  | 
69 | 89 | 
 
  | 
@@ -99,6 +119,7 @@ def get_all_marginals(parser_info: ParserDebugInfo) -> list[dict[str, float]]:  | 
99 | 119 |     return marginals  | 
100 | 120 | 
 
  | 
101 | 121 | 
 
  | 
 | 122 | +# routes  | 
102 | 123 | @app.route("/parser", methods=["POST"])  | 
103 | 124 | @cross_origin()  | 
104 | 125 | def parser():  | 
@@ -194,7 +215,7 @@ def parser():  | 
194 | 215 | 
 
  | 
195 | 216 |         except Exception as ex:  | 
196 | 217 |             traced = "".join(traceback.TracebackException.from_exception(ex).format())  | 
197 |  | -            return error_response(status=500, message=traced)  | 
 | 218 | +            return error_response(status=500, traceback=traced)  | 
198 | 219 | 
 
  | 
199 | 220 |     else:  | 
200 | 221 |         return error_response(status=404)  | 
@@ -235,7 +256,7 @@ def preupload():  | 
235 | 256 | 
 
  | 
236 | 257 |         except Exception as ex:  | 
237 | 258 |             traced = "".join(traceback.TracebackException.from_exception(ex).format())  | 
238 |  | -            return error_response(status=500, message=traced)  | 
 | 259 | +            return error_response(status=500, traceback=traced)  | 
239 | 260 | 
 
  | 
240 | 261 |     else:  | 
241 | 262 |         return error_response(status=404)  | 
@@ -264,7 +285,7 @@ def available_sources():  | 
264 | 285 | 
 
  | 
265 | 286 |         except Exception as ex:  | 
266 | 287 |             traced = "".join(traceback.TracebackException.from_exception(ex).format())  | 
267 |  | -            return error_response(status=500, message=traced)  | 
 | 288 | +            return error_response(status=500, traceback=traced)  | 
268 | 289 | 
 
  | 
269 | 290 |     else:  | 
270 | 291 |         return error_response(status=404)  | 
@@ -328,7 +349,7 @@ def labeller_save():  | 
328 | 349 |         except Exception as ex:  | 
329 | 350 |             traced = "".join(traceback.TracebackException.from_exception(ex).format())  | 
330 | 351 |             print(traced)  | 
331 |  | -            return error_response(status=500, message=traced)  | 
 | 352 | +            return error_response(status=500, traceback=traced)  | 
332 | 353 | 
 
  | 
333 | 354 |     else:  | 
334 | 355 |         return error_response(status=404)  | 
@@ -379,18 +400,12 @@ def labeller_bulk_upload():  | 
379 | 400 | 
 
  | 
380 | 401 |         except Exception as ex:  | 
381 | 402 |             traced = "".join(traceback.TracebackException.from_exception(ex).format())  | 
382 |  | -            return error_response(status=500, message=traced)  | 
 | 403 | +            return error_response(status=500, traceback=traced)  | 
383 | 404 | 
 
  | 
384 | 405 |     else:  | 
385 | 406 |         return error_response(status=404)  | 
386 | 407 | 
 
  | 
387 | 408 | 
 
  | 
388 |  | -def is_valid_dotnum_range(s: str) -> bool:  | 
389 |  | -    """Checks a str against the format "{digit}..{digit}"""  | 
390 |  | - | 
391 |  | -    return bool(re.fullmatch(r"^\d*\.?\d*(?<!\.)\.\.(?!\.)\d*\.?\d*$", s))  | 
392 |  | - | 
393 |  | - | 
394 | 409 | @app.route("/labeller/search", methods=["POST"])  | 
395 | 410 | @cross_origin()  | 
396 | 411 | def labeller_search():  | 
@@ -427,11 +442,10 @@ def labeller_search():  | 
427 | 442 |                 # reserve == for id search  | 
428 | 443 |                 ids_reserved = []  | 
429 | 444 |                 if reserved_char_match in ["=="]:  | 
430 |  | -                    ids_unique = map(str.strip, list(set(sentence[2:].split(","))))  | 
431 | 445 |                     ids_actual = [  | 
432 |  | -                        ix  | 
433 |  | -                        for ix in ids_unique  | 
434 |  | -                        if ix.isdigit() or is_valid_dotnum_range(ix)  | 
 | 446 | +                        ix.strip()  | 
 | 447 | +                        for ix in set(sentence[2:].split(","))  | 
 | 448 | +                        if ix.strip().isdigit() or is_valid_dotnum_range(ix.strip())  | 
435 | 449 |                     ]  | 
436 | 450 | 
 
  | 
437 | 451 |                     for id in ids_actual:  | 
@@ -542,7 +556,7 @@ def labeller_search():  | 
542 | 556 | 
 
  | 
543 | 557 |         except Exception as ex:  | 
544 | 558 |             traced = "".join(traceback.TracebackException.from_exception(ex).format())  | 
545 |  | -            return error_response(status=500, message=traced)  | 
 | 559 | +            return error_response(status=500, traceback=traced)  | 
546 | 560 | 
 
  | 
547 | 561 |     else:  | 
548 | 562 |         return error_response(status=404)  | 
 | 
0 commit comments