3333SQL3_DATABASE_TABLE  =  "en" 
3434SQL3_DATABASE  =  parent_dir  /  "train/data/training.sqlite3" 
3535MODEL_REQUIREMENTS  =  parent_dir  /  "requirements-dev.txt" 
36+ RESERVED_LABELLER_SEARCH_CHARS  =  r"\*\*|\~\~|\=\=" 
3637
3738# sqlite 
3839sqlite3 .register_adapter (list , json .dumps )
@@ -50,19 +51,19 @@ def error_response(status: int, message: str = ""):
5051    """Boilerplate for errors""" 
5152    if  status  ==  400 :
5253        return  jsonify (
53-             {"status" : 400 , "error" : "Sorry, bad params" , "message" : None }
54+             {"status" : 400 , "error" : "Sorry, bad params" , "message" : message }
5455        ), 400 
5556    elif  status  ==  404 :
5657        return  jsonify (
57-             {"status" : 404 , "error" : "Sorry, resource not found" , "message" : None }
58+             {"status" : 404 , "error" : "Sorry, resource not found" , "message" : message }
5859        ), 404 
5960    elif  status  ==  500 :
6061        return  jsonify (
6162            {"status" : 404 , "error" : "Sorry, api failed" , "message" : message }
6263        ), 500 
6364    else :
6465        return  jsonify (
65-             {"status" : status , "error" : "Sorry, something failed" , "message" : None }
66+             {"status" : status , "error" : "Sorry, something failed" , "message" : message }
6667        ), 500 
6768
6869
@@ -111,13 +112,13 @@ def parser():
111112
112113        try :
113114            sentence  =  data .get ("sentence" , "" )
114-             discard_isolated_stop_words  =  data .get ("discard_isolated_stop_words" , False )
115-             expect_name_in_output  =  data .get ("expect_name_in_output" , False )
115+             discard_isolated_stop_words  =  data .get ("discard_isolated_stop_words" , True )
116+             expect_name_in_output  =  data .get ("expect_name_in_output" , True )
116117            string_units  =  data .get ("string_units" , False )
117118            imperial_units  =  data .get ("imperial_units" , False )
118-             foundation_foods  =  data .get ("foundation_foods" , False )
119+             foundation_foods  =  data .get ("foundation_foods" , True )
119120            optimistic_cache_reset  =  data .get ("optimistic_cache_reset" , False )
120-             separate_names  =  data .get ("separate_names" , False )
121+             separate_names  =  data .get ("separate_names" , True )
121122
122123            if  optimistic_cache_reset :
123124                load_parser_model .cache_clear ()
@@ -384,6 +385,12 @@ def labeller_bulk_upload():
384385        return  error_response (status = 404 )
385386
386387
388+ def  is_valid_dotnum_range (s : str ) ->  bool :
389+     """Checks a str against the format "{digit}..{digit}""" 
390+ 
391+     return  bool (re .fullmatch (r"^\d*\.?\d*(?<!\.)\.\.(?!\.)\d*\.?\d*$" , s ))
392+ 
393+ 
387394@app .route ("/labeller/search" , methods = ["POST" ]) 
388395@cross_origin () 
389396def  labeller_search ():
@@ -410,12 +417,36 @@ def labeller_search():
410417                whole_word  =  data .get ("wholeWord" , False )
411418                case_sensitive  =  data .get ("caseSensitive" , False )
412419
420+                 reserved_char_search  =  re .search (
421+                     RESERVED_LABELLER_SEARCH_CHARS , sentence 
422+                 )
423+                 reserved_char_match  =  (
424+                     reserved_char_search .group () if  reserved_char_search  else  None 
425+                 )
426+ 
427+                 # reserve == for id search 
428+                 ids_reserved  =  []
429+                 if  reserved_char_match  in  ["==" ]:
430+                     ids_unique  =  map (str .strip , list (set (sentence [2 :].split ("," ))))
431+                     ids_actual  =  [
432+                         ix 
433+                         for  ix  in  ids_unique 
434+                         if  ix .isdigit () or  is_valid_dotnum_range (ix )
435+                     ]
436+ 
437+                     for  id  in  ids_actual :
438+                         if  is_valid_dotnum_range (id ):
439+                             start , stop  =  id .split (".." )
440+                             ids_reserved .extend (range (int (start ), int (stop ) +  1 ))
441+                         elif  id .isdigit ():
442+                             ids_reserved .append (int (id ))
443+ 
413444                # preprocess for correct token comparison later 
414445                sentence_preprocessed  =  PreProcessor (sentence ).sentence 
415446                # reserve ** or ~~ for wildcard, treat as empty string 
416447                sentence_cleansed  =  (
417448                    " " 
418-                     if  re . search ( r"\*\*|~~ " , sentence_preprocessed ) 
449+                     if  reserved_char_match   in  [ "** " , "~~" ] 
419450                    else  sentence_preprocessed 
420451                )
421452
@@ -456,8 +487,10 @@ def labeller_search():
456487                            if  label  in  labels 
457488                        ]
458489                    )
459-                     if  query .search (partial_sentence ) or  (
460-                         partial_sentence  ==  sentence_cleansed 
490+                     if  (
491+                         row ["id" ] in  ids_reserved 
492+                         or  query .search (partial_sentence )
493+                         or  partial_sentence  ==  sentence_cleansed 
461494                    ):
462495                        indices .append (row ["id" ])
463496
@@ -468,7 +501,7 @@ def labeller_search():
468501                    cursor .execute (
469502                        f""" 
470503                        SELECT * 
471-                         FROM en  
504+                         FROM { SQL3_DATABASE_TABLE }  
472505                        WHERE id IN ({ "," .join (["?" ] *  len (batch ))}  ) 
473506                        """ ,
474507                        (batch ),
@@ -488,7 +521,7 @@ def labeller_search():
488521                    cursor .execute (
489522                        f""" 
490523                        SELECT COUNT(*) 
491-                         FROM en  
524+                         FROM { SQL3_DATABASE_TABLE }  
492525                        WHERE id IN ({ "," .join (["?" ] *  len (batch ))}  ) 
493526                        """ ,
494527                        (batch ),
0 commit comments