Skip to content

Commit 455aa3e

Browse files
committed
Bugfix: De-duplicate ingredient names.
1 parent da9a23e commit 455aa3e

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

ingredient_parser/en/postprocess.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,40 @@ def _postprocess_names(self) -> tuple[list[IngredientText], list[FoundationFood]
288288
if ff:
289289
foundation_foods.add(ff)
290290

291-
return names, list(foundation_foods)
291+
return self._deduplicate_names(names), list(foundation_foods)
292+
293+
def _deduplicate_names(self, names: list[IngredientText]) -> list[IngredientText]:
    """Merge IngredientText objects that share the same text.

    Names are grouped by their text. Each group is collapsed into a single
    IngredientText whose confidence is the mean of the group's confidence
    values and whose starting_index is the smallest starting_index found in
    the group. The result is ordered by the first appearance of each text in
    ``names``.

    Parameters
    ----------
    names : list[IngredientText]
        List of names, possibly containing repeated text.

    Returns
    -------
    list[IngredientText]
        Deduplicated list of names.
    """
    grouped = defaultdict(list)
    for item in names:
        grouped[item.text].append(item)

    # dicts preserve insertion order, so the output follows the order in
    # which each distinct text was first seen.
    return [
        IngredientText(
            text=label,
            confidence=mean([member.confidence for member in members]),
            starting_index=min(member.starting_index for member in members),
        )
        for label, members in grouped.items()
    ]
292325

293326
def _group_name_labels(self, name_labels: list[str]) -> list[list[tuple[int, str]]]:
294327
"""Group name labels according to name label type.

tests/postprocess/test_process_names.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,21 @@ def test_multiple_modified_variant_names(self):
112112
p = PostProcessor(sentence, tokens, pos_tags, labels, scores)
113113
names, _ = p._postprocess_names()
114114
assert names == expected
115+
116+
def test_deuplicate_ingredient_names(self):
    """
    Test that a name repeated in the sentence is returned as a single
    IngredientText object
    """
    sentence = "1/2 cup sugar plus 1 1/2 tablespoons sugar"
    # One row per token: (token, part-of-speech tag, label).
    tagged_tokens = [
        ("#1$2", "CD", "QTY"),
        ("cup", "NN", "UNIT"),
        ("sugar", "NN", "B_NAME_TOK"),
        ("plus", "CC", "COMMENT"),
        ("1#1$2", "CD", "QTY"),
        ("tablespoon", "NN", "UNIT"),
        ("sugar", "NN", "B_NAME_TOK"),
    ]
    tokens, pos_tags, labels = (list(column) for column in zip(*tagged_tokens))
    scores = [0.0 for _ in tokens]

    processor = PostProcessor(sentence, tokens, pos_tags, labels, scores)
    names, _ = processor._postprocess_names()

    # Both "sugar" tokens collapse into one entry that keeps the earliest index.
    assert names == [
        IngredientText(text="sugar", confidence=0, starting_index=2),
    ]

0 commit comments

Comments
 (0)