@@ -70,11 +70,17 @@ class TokenStats:
 
 
 @dataclass
-class FFTokenStats:
+class TokenStatsCombinedName:
     """Statistics for token classification performance."""
 
-    FF: Metrics
-    NF: Metrics
+    NAME: Metrics
+    QTY: Metrics
+    UNIT: Metrics
+    SIZE: Metrics
+    COMMENT: Metrics
+    PURPOSE: Metrics
+    PREP: Metrics
+    PUNC: Metrics
     macro_avg: Metrics
     weighted_avg: Metrics
     accuracy: float
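
The dataclass above is filled by unpacking a plain dict whose keys match its field names (see the final evaluate() hunk below). A minimal sketch of that pattern, assuming the TokenStatsCombinedName definition from the hunk above is in scope and using a simplified stand-in for Metrics, whose real fields are not part of this diff:

from dataclasses import dataclass

@dataclass
class Metrics:
    # Stand-in only: the real Metrics class is defined elsewhere in the module
    # and may carry different fields.
    precision: float
    recall: float
    f1: float

placeholder = Metrics(0.0, 0.0, 0.0)  # illustrative values
token_stats = {
    label: placeholder
    for label in ("NAME", "QTY", "UNIT", "SIZE", "COMMENT", "PURPOSE", "PREP", "PUNC")
}
token_stats.update(macro_avg=placeholder, weighted_avg=placeholder, accuracy=0.0)
# The dict keys must match the dataclass field names exactly for ** unpacking to work.
stats = TokenStatsCombinedName(**token_stats)
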
@@ -91,7 +97,7 @@ class SentenceStats:
 class Stats:
     """Statistics for token and sentence classification performance."""
 
-    token: TokenStats | FFTokenStats
+    token: TokenStats | TokenStatsCombinedName
     sentence: SentenceStats
     seed: int
 
@@ -161,6 +167,7 @@ def load_datasets(
     table: str,
     datasets: list[str],
     discard_other: bool = True,
+    combine_name_labels: bool = False,
 ) -> DataVectors:
     """Load raw data from csv files and transform into format required for training.
 
@@ -176,6 +183,8 @@ def load_datasets(
         Default is PARSER.
     discard_other : bool, optional
         If True, discard sentences containing tokens with OTHER label
+    combine_name_labels : bool, optional
+        If True, combine all labels containing "NAME" into a single "NAME" label
 
     Returns
     -------
@@ -216,6 +225,7 @@ def load_datasets(
                 chunks,
                 [PreProcessor] * n_chunks,
                 [discard_other] * n_chunks,
+                [combine_name_labels] * n_chunks,
             )
         ]
 
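
Broadcasting [combine_name_labels] * n_chunks keeps the new flag aligned with the other per-chunk arguments, so every worker invocation of process_sentences receives the same settings. A generic sketch of this broadcast pattern, using concurrent.futures.ProcessPoolExecutor.map as an assumed stand-in for the surrounding call, which is only partially visible in the hunk above:

import concurrent.futures

def work(chunk, preprocessor, discard_other, combine_name_labels):
    # Placeholder worker standing in for process_sentences in this sketch.
    return len(chunk)

if __name__ == "__main__":
    chunks = [["1 cup flour"], ["2 tbsp butter", "salt, to taste"]]  # hypothetical data
    n_chunks = len(chunks)
    with concurrent.futures.ProcessPoolExecutor() as executor:
        results = list(
            executor.map(
                work,
                chunks,
                [None] * n_chunks,   # placeholder for the PreProcessor class
                [True] * n_chunks,   # discard_other
                [False] * n_chunks,  # combine_name_labels
            )
        )
    print(results)  # [1, 2]
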
@@ -235,7 +245,10 @@ def load_datasets(
 
 
 def process_sentences(
-    data: list[dict], PreProcessor: Callable, discard_other: bool
+    data: list[dict],
+    PreProcessor: Callable,
+    discard_other: bool,
+    combine_name_labels: bool,
 ) -> DataVectors:
     """Process training sentences from database into format needed for training and
     evaluation.
@@ -247,7 +260,9 @@ def process_sentences(
     PreProcessor : Callable
         PreProcessor class to preprocess sentences.
     discard_other : bool
-        If True, discard sentences with OTHER label
+        If True, discard sentences with OTHER
+    combine_name_labels : bool
+        If True, combine all labels containing "NAME" into a single "NAME" label
 
     Returns
     -------
@@ -278,7 +293,17 @@ def process_sentences(
         uids.append(entry["id"])
         features.append(p.sentence_features())
         tokens.append([t.text for t in p.tokenized_sentence])
-        labels.append(entry["labels"])
+
+        if combine_name_labels:
+            new_labels = []
+            for label in entry["labels"]:
+                if "NAME" in label:
+                    new_labels.append("NAME")
+                else:
+                    new_labels.append(label)
+            labels.append(new_labels)
+        else:
+            labels.append(entry["labels"])
 
         # Ensure length of tokens and length of labels are the same
         if len(p.tokenized_sentence) != len(entry["labels"]):
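
The new branch above simply collapses every label containing "NAME" to plain "NAME" and leaves all other labels untouched. A small illustration of the same rule written as a comprehension; the fine-grained label names in the example (e.g. NAME_VAR) are hypothetical, not taken from this diff:

def collapse_name_labels(labels: list[str]) -> list[str]:
    # Same rule as the loop in the hunk above.
    return ["NAME" if "NAME" in label else label for label in labels]

print(collapse_name_labels(["QTY", "UNIT", "NAME", "NAME_VAR", "COMMENT"]))
# -> ['QTY', 'UNIT', 'NAME', 'NAME', 'COMMENT']
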
@@ -297,6 +322,7 @@ def evaluate(
     predictions: list[list[str]],
     truths: list[list[str]],
     seed: int,
+    combine_name_labels: bool,
 ) -> Stats:
     """Calculate statistics on the predicted labels for the test data.
 
@@ -308,6 +334,8 @@ def evaluate(
         True labels for each test sentence
     seed : int
         Seed value that produced the results
+    combine_name_labels : bool
+        If True, all NAME labels are combined into a single NAME label
 
     Returns
     -------
@@ -338,7 +366,10 @@ def evaluate(
             )
 
     token_stats["accuracy"] = accuracy_score(flat_truths, flat_predictions)
-    token_stats = TokenStats(**token_stats)
+    if combine_name_labels:
+        token_stats = TokenStatsCombinedName(**token_stats)
+    else:
+        token_stats = TokenStats(**token_stats)
 
     # Generate sentence statistics
     # The only statistics that makes sense here is accuracy because there are only
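
Since Stats.token is now TokenStats | TokenStatsCombinedName, code that consumes the evaluation results may need to branch on the concrete type before reading label-specific fields. A hedged sketch, assuming the classes from the hunks above are in scope and that Metrics exposes an f1 attribute (not shown in this diff):

def combined_name_f1(stats: Stats) -> float | None:
    # Only the combined-name variant is guaranteed to have a single NAME field;
    # the fields of the original TokenStats are outside this diff.
    if isinstance(stats.token, TokenStatsCombinedName):
        return stats.token.NAME.f1
    return None
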