Skip to content

Commit 049f595

Browse files
authored
Merge pull request #14 from wojiaodawei/feature/add-gliner-predict-params-flatner-threshold-multilabel
Now handles the following GLiNER predict parameters: flat_ner, threshold, and multi_label
2 parents 283bb4a + e97b82d commit 049f595

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

src/utca/implementation/predictors/gliner_predictor/predictor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,12 @@ def invoke(self, input_data: GLiNERPredictorInput, evaluator: Evaluator) -> Dict
7070
"""
7171
if not input_data.labels:
7272
return {"output": [[]]*len(input_data.texts)}
73-
labels = set(input_data.labels)
7473
texts = input_data.texts
75-
outputs = self.model.batch_predict_entities(texts=texts, labels=labels) # type: ignore
74+
labels = set(input_data.labels)
75+
flat_ner = input_data.flat_ner
76+
threshold = input_data.threshold
77+
multi_label = input_data.multi_label
78+
outputs = self.model.batch_predict_entities(texts=texts, labels=labels, flat_ner=flat_ner, threshold=threshold, multi_label=multi_label) # type: ignore
7679
return ensure_dict(outputs)
7780

7881

src/utca/implementation/predictors/gliner_predictor/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ class GLiNERPredictorInput(IOModel):
3535
"""
3636
texts: List[str]
3737
labels: List[str]
38+
flat_ner: bool = True
3839
threshold: float = 0.5
39-
40+
multi_label: bool = False
4041

4142

4243
class GLiNERPredictorOutput(IOModel):

src/utca/implementation/tasks/text_processing/ner/gliner_task/actions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,39 @@ class GLiNERPreprocessor(Action[Dict[str, Any], Dict[str, Any]]):
2020
2121
"chunks_starts" (List[int]): Chunks start positions. Used by postprocessor;
2222
23+
"flat_ner" (bool): Whether to use flat NER;
24+
2325
"threshold" (float): Minimal score for an entity to put into output;
26+
27+
"multi_label" (bool): Whether to allow multiple labels per input;
2428
"""
2529

2630
def __init__(
2731
self,
2832
sents_batch: int=10,
33+
flat_ner: bool=True,
2934
threshold: float=0.5,
35+
multi_label: bool=False,
3036
name: Optional[str]=None,
3137
) -> None:
3238
"""
3339
Args:
3440
sents_batch (int): Chunks size in sentences. Defaults to 10.
3541
36-
threshold (float): Minimial score to put entities into the output.
42+
flat_ner (bool): Whether to use flat NER. Defaults to True.
43+
44+
threshold (float): Minimal score to put entities into the output. Defaults to 0.5.
45+
46+
multi_label (bool): Whether to allow multiple labels per entity. Defaults to False.
3747
3848
name (Optional[str], optional): Name for identification. If equals to None,
3949
class name will be used. Defaults to None.
4050
"""
4151
super().__init__(name)
42-
self.threshold = threshold
4352
self.sents_batch = sents_batch
53+
self.flat_ner = flat_ner
54+
self.threshold = threshold
55+
self.multi_label = multi_label
4456

4557

4658
def get_last_sentence_id(self, i: int, sentences_len: int) -> int:
@@ -78,15 +90,21 @@ def execute(
7890
7991
"chunks_starts" (List[int]): Chunks start positions. Used by postprocessor;
8092
93+
"flat_ner" (bool): Whether to use flat NER;
94+
8195
"threshold" (float): Minimal score for an entity to put into output;
96+
97+
"multi_label" (bool): Whether to allow multiple labels per input;
8298
"""
8399
chunks, chunks_starts = (
84100
self.chunkanize(input_data["text"])
85101
)
86102
return {
87103
"texts": chunks,
88104
"chunks_starts": chunks_starts,
105+
"flat_ner": self.flat_ner,
89106
"threshold": self.threshold,
107+
"multi_label": self.multi_label,
90108
}
91109

92110

0 commit comments

Comments
 (0)