@@ -271,31 +271,18 @@ def _postprocess_names(self) -> tuple[list[IngredientText], list[FoundationFood]
271271
272272        name_labels  =  [self .labels [i ] for  i  in  name_idx ]
273273        bio_groups  =  self ._group_name_labels (name_labels )
274-         constructed_names  =  self ._construct_names (bio_groups )
275- 
276-         names  =  []
277-         foundation_foods  =  set ()  # Use a set to avoid duplicates 
278-         for  group  in  constructed_names :
279-             # Convert from name_label indices to token indices 
280-             token_idx  =  [name_idx [idx ] for  idx  in  group ]
281-             ing_text  =  self ._postprocess_indices (token_idx , "NAME" )
282-             if  ing_text  is  not   None :
283-                 names .append (ing_text )
284- 
285-                 if  self .foundation_foods :
286-                     tokens  =  [self .tokens [i ] for  i  in  token_idx ]
287-                     ff  =  match_foundation_foods (tokens )
288-                     if  ff :
289-                         foundation_foods .add (ff )
290- 
291-         return  self ._deduplicate_names (names ), list (foundation_foods )
274+         constructed_names  =  self ._construct_names_from_bio_groups (bio_groups )
275+         names , foundation_foods  =  self ._convert_name_indices_to_object (
276+             name_idx , constructed_names 
277+         )
278+         return  names , foundation_foods 
292279
293280    def  _deduplicate_names (self , names : list [IngredientText ]) ->  list [IngredientText ]:
294281        """Deduplicate list of names. 
295282
296283        Where the same name text appears in multiple IngredientText objects, the 
297284        confidence values are averaged, and the minimum starting_index is kept for the 
298-         dedeuplicated  names. 
285+         deduplicated  names. 
299286
300287        Parameters 
301288        ---------- 
@@ -305,7 +292,7 @@ def _deduplicate_names(self, names: list[IngredientText]) -> list[IngredientText
305292        Returns 
306293        ------- 
307294        list[IngredientText] 
308-             Deduplicaed  list of names. 
295+             Deduplicated  list of names. 
309296        """ 
310297        name_dict  =  defaultdict (list )
311298        for  name  in  names :
@@ -381,7 +368,7 @@ def _group_name_labels(self, name_labels: list[str]) -> list[list[tuple[int, str
381368
382369        return  name_groups 
383370
384-     def  _construct_names (
371+     def  _construct_names_from_bio_groups (
385372        self , name_groups : list [list [tuple [int , str ]]]
386373    ) ->  list [list [int ]]:
387374        """Construct names from BIO groups. 
@@ -435,7 +422,7 @@ def _construct_names(
435422                    last_encountered_name_used  =  True 
436423                else :
437424                    # If we are here, then we've come across a VAR group that does not 
438-                     # preceed  a TOK group, so the model has made an error in it's 
425+                     # precede  a TOK group, so the model has made an error in it's 
439426                    # labelling. Add this VAR group anyway. 
440427                    constructed_names .append (current_group_idx )
441428
@@ -480,6 +467,87 @@ def _get_name_group_label(self, labels: tuple[str]) -> str:
480467
481468        return  "" 
482469
470+     def  _convert_name_indices_to_object (
471+         self , name_idx : list [int ], name_indices : list [list [int ]]
472+     ) ->  tuple [list [IngredientText ], list [FoundationFood ]]:
473+         """Convert grouped indices for name tokens into IngredientText objects. If 
474+         foundation foods are enabled, determine matching foundation food for each name. 
475+ 
476+         If an ingredient name ends with a token with POS tag of DT, IN or JJ, merge it 
477+         with the next name group, if there is one. This is to avoid cases in a sentence 
478+         like "5 fresh large basil leaves" where "large" is given the SIZE label, 
479+         resulting in two separate names: "fresh" and "basil leaves". Instead, we want to 
480+         return a single name: "fresh basil leaves". 
481+ 
482+         Parameters 
483+         ---------- 
484+         name_idx : list[int] 
485+             List of indices of NAME tokens. 
486+         name_indices : list[list[int]] 
487+             List of groups of indices corresponding to ingredient names. 
488+ 
489+         Returns 
490+         ------- 
491+         tuple[list[IngredientText], list[FoundationFood]] 
492+             List of deduplicated IngredientText objects and FoundationFoods objects. 
493+         """ 
494+         names  =  []
495+         foundation_foods  =  set ()  # Use a set to avoid duplicates 
496+ 
497+         # Keep track of IngredientText objects and indices to merge with next. 
498+         # We do the merge if the name ends with DT, IN, JJ part of speech tag. 
499+         merge_with_next : IngredientText  |  None  =  None 
500+         merge_with_next_idx : list [int ] |  None  =  None 
501+ 
502+         for  group  in  name_indices :
503+             # Convert from name_label indices to token indices 
504+             token_idx  =  [name_idx [idx ] for  idx  in  group ]
505+             ing_text  =  self ._postprocess_indices (token_idx , "NAME" )
506+             if  ing_text  is  None :
507+                 continue 
508+ 
509+             if  merge_with_next  and  merge_with_next_idx :
510+                 # If we need to merge the previous name, do it now. 
511+                 ing_text  =  IngredientText (
512+                     text = merge_with_next .text  +  " "  +  ing_text .text ,
513+                     confidence = (merge_with_next .confidence  +  ing_text .confidence ) /  2 ,
514+                     starting_index = min (
515+                         [merge_with_next .starting_index , ing_text .starting_index ]
516+                     ),
517+                 )
518+                 token_idx  =  [* merge_with_next_idx , * token_idx ]
519+ 
520+             if  self .pos_tags [token_idx [- 1 ]] in  {"DT" , "IN" , "JJ" }:
521+                 # Mark name for merging with next name. 
522+                 merge_with_next  =  ing_text 
523+                 merge_with_next_idx  =  token_idx 
524+                 # Skip to next iteration 
525+                 continue 
526+             else :
527+                 names .append (ing_text )
528+                 merge_with_next  =  None 
529+                 merge_with_next_idx  =  None 
530+ 
531+                 if  self .foundation_foods :
532+                     # Bug: token_idx is wrong here if we merged names 
533+                     tokens  =  [self .tokens [i ] for  i  in  token_idx ]
534+                     ff  =  match_foundation_foods (tokens )
535+                     if  ff :
536+                         foundation_foods .add (ff )
537+ 
538+         if  merge_with_next  and  merge_with_next_idx :
539+             # Catch any remaining IngredientText objects marked as needing to be merged 
540+             # but haven't been. 
541+             names .append (merge_with_next )
542+             if  self .foundation_foods :
543+                 # Bug: token_idx is wrong here if we merged names 
544+                 tokens  =  [self .tokens [i ] for  i  in  merge_with_next_idx ]
545+                 ff  =  match_foundation_foods (tokens )
546+                 if  ff :
547+                     foundation_foods .add (ff )
548+ 
549+         return  self ._deduplicate_names (names ), list (foundation_foods )
550+ 
483551    def  _postprocess_indices (
484552        self , label_idx : list [int ], selected_label : str 
485553    ) ->  IngredientText  |  None :
0 commit comments