
Commit 58ff348

Feature: Add training option to combine all NAME_* labels into a single NAME label
1 parent 16a7e94 commit 58ff348

5 files changed: +105 additions, −16 deletions
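The new --combine-name-labels option collapses every label whose name contains "NAME" into the single label NAME before the training vectors are built and before evaluation. A minimal standalone sketch of that rule (the example labels are illustrative, not taken from the training data):

def combine_name_labels(labels: list[str]) -> list[str]:
    # Mirror of the branch added to process_sentences() in this commit:
    # any label containing "NAME" is replaced with the single label "NAME".
    return ["NAME" if "NAME" in label else label for label in labels]

print(combine_name_labels(["NAME_VARIANT", "QTY", "UNIT", "COMMENT"]))
# -> ['NAME', 'QTY', 'UNIT', 'COMMENT']  (NAME_VARIANT is a made-up example label)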

train.py

Lines changed: 20 additions & 0 deletions

@@ -93,6 +93,11 @@ def __call__(self, parser, namespace, values, option_strings):
         action="store_true",
         help="Plot confusion matrix of token labels.",
     )
+    train_parser.add_argument(
+        "--combine-name-labels",
+        action="store_true",
+        help="Combine labels containing 'NAME' into a single NAME label.",
+    )
     train_parser.add_argument(
         "-v",
         help="Enable verbose output.",
@@ -150,6 +155,11 @@ def __call__(self, parser, namespace, values, option_strings):
         action="store_true",
         help="Plot confusion matrix of token labels.",
     )
+    multiple_parser.add_argument(
+        "--combine-name-labels",
+        action="store_true",
+        help="Combine labels containing 'NAME' into a single NAME label.",
+    )
     multiple_parser.add_argument(
         "-r",
         "--runs",
@@ -227,6 +237,11 @@ def __call__(self, parser, namespace, values, option_strings):
         type=int,
         help="Seed value used for train/test split.",
     )
+    gridsearch_parser.add_argument(
+        "--combine-name-labels",
+        action="store_true",
+        help="Combine labels containing 'NAME' into a single NAME label.",
+    )
     gridsearch_parser.add_argument(
         "--algos",
         default=["lbfgs"],
@@ -327,6 +342,11 @@ def __call__(self, parser, namespace, values, option_strings):
         default=False,
         help="Keep models after evaluation instead of deleting.",
     )
+    featuresearch_parser.add_argument(
+        "--combine-name-labels",
+        action="store_true",
+        help="Combine labels containing 'NAME' into a single NAME label.",
+    )
     featuresearch_parser.add_argument(
         "-p",
         "--processes",

train/featuresearch.py

Lines changed: 12 additions & 2 deletions

@@ -74,6 +74,7 @@ def train_model_feature_search(
     save_model: str,
     seed: int,
     keep_model: bool,
+    combine_name_labels: bool,
 ) -> dict:
     """Train model using selected features returning model performance statistics,
     model parameters and elapsed training time.
@@ -93,6 +94,8 @@ def train_model_feature_search(
         testing sets.
     keep_model : bool
         If True, keep model after evaluation, otherwise delete it.
+    combine_name_labels : bool, optional
+        If True, combine all NAME labels into a single NAME label.

     Returns
     -------
@@ -157,7 +160,7 @@ def train_model_feature_search(
     tagger = pycrfsuite.Tagger() # type: ignore
     tagger.open(str(save_model_path))
     labels_pred = [tagger.tag(X) for X in features_test]
-    stats = evaluate(labels_pred, truth_test, seed)
+    stats = evaluate(labels_pred, truth_test, seed, combine_name_labels)

     if not keep_model:
         save_model_path.unlink(missing_ok=True)
@@ -179,7 +182,13 @@ def feature_search(args: argparse.Namespace):
     args : argparse.Namespace
         Feature search configuration
     """
-    vectors = load_datasets(args.database, args.table, args.datasets)
+    vectors = load_datasets(
+        args.database,
+        args.table,
+        args.datasets,
+        discard_other=True,
+        combine_name_labels=args.combine_name_labels,
+    )

     if args.save_model is None:
         save_model = DEFAULT_MODEL_LOCATION
@@ -195,6 +204,7 @@ def feature_search(args: argparse.Namespace):
             save_model,
             args.seed,
             args.keep_models,
+            args.combine_name_labels,
         ]
         argument_sets.append(arguments)
train/gridsearch.py

Lines changed: 12 additions & 2 deletions

@@ -321,7 +321,13 @@ def generate_argument_sets(args: argparse.Namespace) -> list[list]:
         list of lists, where each sublist is the arguments for training a model with
         one of the combinations of algorithms and parameters
     """
-    vectors = load_datasets(args.database, args.table, args.datasets)
+    vectors = load_datasets(
+        args.database,
+        args.table,
+        args.datasets,
+        discard_other=True,
+        combine_name_labels=args.combine_name_labels,
+    )

     # Generate list of arguments for all combinations parameters for each algorithm
     argument_sets = []
@@ -360,6 +366,7 @@ def generate_argument_sets(args: argparse.Namespace) -> list[list]:
             save_model,
             args.seed,
             args.keep_models,
+            args.combine_name_labels,
         ]
         argument_sets.append(arguments)

@@ -374,6 +381,7 @@ def train_model_grid_search(
     save_model: str,
     seed: int,
     keep_model: bool,
+    combine_name_labels: bool,
 ) -> dict:
     """Train model using given training algorithm and parameters,
     returning model performance statistics, model parameters and elapsed training time.
@@ -395,6 +403,8 @@ def train_model_grid_search(
         testing sets.
     keep_model : bool
         If True, keep model after evaluation, otherwise delete it.
+    combine_name_labels : bool, optional
+        If True, combine all NAME labels into a single NAME label.

     Returns
     -------
@@ -443,7 +453,7 @@ def train_model_grid_search(
     tagger = pycrfsuite.Tagger() # type: ignore
     tagger.open(str(save_model_path))
     labels_pred = [tagger.tag(X) for X in features_test]
-    stats = evaluate(labels_pred, truth_test, seed)
+    stats = evaluate(labels_pred, truth_test, seed, combine_name_labels)

     if not keep_model:
         save_model_path.unlink(missing_ok=True)
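Both generate_argument_sets here and feature_search above build positional argument lists that are later unpacked into the training functions, so args.combine_name_labels has to be appended at the same position that combine_name_labels occupies in the function signature. A minimal sketch of that pattern, assuming the argument sets are consumed by something like multiprocessing.Pool.starmap (the worker and its other arguments are illustrative):

from multiprocessing import Pool

def train_worker(save_model, seed, keep_model, combine_name_labels):
    # Stand-in for train_model_grid_search: parameter order must match the
    # order in which values were appended to each argument list.
    return (save_model, seed, keep_model, combine_name_labels)

argument_sets = [
    ["model_a.crfsuite", 42, False, True],   # combine_name_labels appended last
    ["model_b.crfsuite", 43, False, False],
]

if __name__ == "__main__":
    with Pool(2) as pool:
        print(pool.starmap(train_worker, argument_sets))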

train/train_model.py

Lines changed: 22 additions & 4 deletions

@@ -36,6 +36,7 @@ def train_parser_model(
     detailed_results: bool,
     plot_confusion_matrix: bool,
     keep_model: bool = True,
+    combine_name_labels: bool = False,
 ) -> Stats:
     """Train model using vectors, splitting the vectors into a train and evaluation
     set based on <split>. The trained model is saved to <save_model>.
@@ -59,9 +60,12 @@ def train_parser_model(
         the test set.
     plot_confusion_matrix : bool
         If True, plot a confusion matrix of the token labels.
-    kee[_model : bool, optional
+    keep_model : bool, optional
         If False, delete model from disk after evaluating it's performance.
         Default is True.
+    combine_name_labels : bool, optional
+        If True, combine all NAME labels into a single NAME label.
+        Default is False

     Returns
     -------
@@ -154,7 +158,7 @@ def train_parser_model(
     if plot_confusion_matrix:
         confusion_matrix(labels_pred, truth_test)

-    stats = evaluate(labels_pred, truth_test, seed)
+    stats = evaluate(labels_pred, truth_test, seed, combine_name_labels)

     if not keep_model:
         save_model.unlink(missing_ok=True)
@@ -170,7 +174,13 @@ def train_single(args: argparse.Namespace) -> None:
     args : argparse.Namespace
         Model training configuration
     """
-    vectors = load_datasets(args.database, args.table, args.datasets)
+    vectors = load_datasets(
+        args.database,
+        args.table,
+        args.datasets,
+        discard_other=True,
+        combine_name_labels=args.combine_name_labels,
+    )

     if args.save_model is None:
         save_model = DEFAULT_MODEL_LOCATION
@@ -186,6 +196,7 @@ def train_single(args: argparse.Namespace) -> None:
         args.detailed,
         args.confusion,
         keep_model=True,
+        combine_name_labels=args.combine_name_labels,
     )

     print("Sentence-level results:")
@@ -208,7 +219,13 @@ def train_multiple(args: argparse.Namespace) -> None:
     args : argparse.Namespace
         Model training configuration
     """
-    vectors = load_datasets(args.database, args.table, args.datasets)
+    vectors = load_datasets(
+        args.database,
+        args.table,
+        args.datasets,
+        discard_other=True,
+        combine_name_labels=args.combine_name_labels,
+    )

     if args.save_model is None:
         save_model = DEFAULT_MODEL_LOCATION
@@ -227,6 +244,7 @@ def train_multiple(args: argparse.Namespace) -> None:
             args.detailed,
             args.confusion,
             False, # keep_model
+            args.combine_name_labels,
         )
         for _ in range(args.runs)
     ]

train/training_utils.py

Lines changed: 39 additions & 8 deletions

@@ -70,11 +70,17 @@ class TokenStats:


 @dataclass
-class FFTokenStats:
+class TokenStatsCombinedName:
     """Statistics for token classification performance."""

-    FF: Metrics
-    NF: Metrics
+    NAME: Metrics
+    QTY: Metrics
+    UNIT: Metrics
+    SIZE: Metrics
+    COMMENT: Metrics
+    PURPOSE: Metrics
+    PREP: Metrics
+    PUNC: Metrics
     macro_avg: Metrics
     weighted_avg: Metrics
     accuracy: float
@@ -91,7 +97,7 @@ class SentenceStats:
 class Stats:
     """Statistics for token and sentence classification performance."""

-    token: TokenStats | FFTokenStats
+    token: TokenStats | TokenStatsCombinedName
     sentence: SentenceStats
     seed: int

@@ -161,6 +167,7 @@ def load_datasets(
     table: str,
     datasets: list[str],
     discard_other: bool = True,
+    combine_name_labels: bool = False,
 ) -> DataVectors:
     """Load raw data from csv files and transform into format required for training.

@@ -176,6 +183,8 @@ def load_datasets(
         Default is PARSER.
     discard_other : bool, optional
         If True, discard sentences containing tokens with OTHER label
+    combine_name_labels : bool, optional
+        If True, combine all labels containing "NAME" into a single "NAME" label

     Returns
     -------
@@ -216,6 +225,7 @@ def load_datasets(
             chunks,
             [PreProcessor] * n_chunks,
             [discard_other] * n_chunks,
+            [combine_name_labels] * n_chunks,
         )
     ]

@@ -235,7 +245,10 @@ def load_datasets(


 def process_sentences(
-    data: list[dict], PreProcessor: Callable, discard_other: bool
+    data: list[dict],
+    PreProcessor: Callable,
+    discard_other: bool,
+    combine_name_labels: bool,
 ) -> DataVectors:
     """Process training sentences from database into format needed for training and
     evaluation.
@@ -247,7 +260,9 @@ def process_sentences(
     PreProcessor : Callable
         PreProcessor class to preprocess sentences.
     discard_other : bool
-        If True, discard sentences with OTHER label
+        If True, discard sentences with OTHER
+    combine_name_labels : bool
+        If True, combine all labels containing "NAME" into a single "NAME" label

     Returns
     -------
@@ -278,7 +293,17 @@ def process_sentences(
         uids.append(entry["id"])
         features.append(p.sentence_features())
         tokens.append([t.text for t in p.tokenized_sentence])
-        labels.append(entry["labels"])
+
+        if combine_name_labels:
+            new_labels = []
+            for label in entry["labels"]:
+                if "NAME" in label:
+                    new_labels.append("NAME")
+                else:
+                    new_labels.append(label)
+            labels.append(new_labels)
+        else:
+            labels.append(entry["labels"])

         # Ensure length of tokens and length of labels are the same
         if len(p.tokenized_sentence) != len(entry["labels"]):
@@ -297,6 +322,7 @@ def evaluate(
     predictions: list[list[str]],
     truths: list[list[str]],
     seed: int,
+    combine_name_labels: bool,
 ) -> Stats:
     """Calculate statistics on the predicted labels for the test data.

@@ -308,6 +334,8 @@ def evaluate(
         True labels for each test sentence
     seed : int
         Seed value that produced the results
+    combine_name_labels : bool
+        If True, all NAME labels are combined into a single NAME label

     Returns
     -------
@@ -338,7 +366,10 @@ def evaluate(
     )

     token_stats["accuracy"] = accuracy_score(flat_truths, flat_predictions)
-    token_stats = TokenStats(**token_stats)
+    if combine_name_labels:
+        token_stats = TokenStatsCombinedName(**token_stats)
+    else:
+        token_stats = TokenStats(**token_stats)

     # Generate sentence statistics
     # The only statistics that makes sense here is accuracy because there are only
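The separate TokenStatsCombinedName dataclass exists because its fields must match the per-label metric keys produced during evaluation, and merging the NAME_* labels changes that key set to a single NAME entry. A hedged sketch of the effect, assuming per-label metrics come from something like sklearn's classification_report (the project may compute them differently; the labels below are toy data):

from sklearn.metrics import classification_report

truth = ["NAME", "NAME", "QTY", "UNIT", "COMMENT"]
pred = ["NAME", "COMMENT", "QTY", "UNIT", "COMMENT"]

# With combined labels there is one NAME key, so the stats container exposes a
# single NAME field (alongside QTY, UNIT, SIZE, COMMENT, PURPOSE, PREP, PUNC)
# rather than one field per NAME_* variant.
report = classification_report(truth, pred, output_dict=True, zero_division=0)
print(sorted(key for key in report if key.isupper()))
# -> ['COMMENT', 'NAME', 'QTY', 'UNIT']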
